add sigmoid/softmax support and multi-class extension for AsymmetricU… #8642
base: dev
Conversation
…nifiedFocalLoss Signed-off-by: ytl0623 <[email protected]>
Walkthrough
Adds a …
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (3 passed)
Actionable comments posted: 5
🧹 Nitpick comments (2)
monai/losses/unified_focal_loss.py (1)
243-244: Redundant shape check.
Both AsymmetricFocalLoss and AsymmetricFocalTverskyLoss already validate shapes. This check in the parent is redundant when to_onehot_y=False. When to_onehot_y=True, the y_true shape changes after one-hot encoding in the sub-losses, making this pre-check potentially incorrect. Consider removing this check or adjusting it to account for the to_onehot_y case:

- if y_pred.shape != y_true.shape:
-     raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
+ if not self.to_onehot_y and y_pred.shape != y_true.shape:
+     raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

tests/losses/test_unified_focal_loss.py (1)
22-50: Test coverage looks reasonable but could be expanded.
Cases cover:
- Binary segmentation with default params
- Binary segmentation with explicit sigmoid
- Multi-class with softmax + one-hot
Missing coverage for edge cases:
- use_softmax=True with binary (single foreground channel)
- Multi-class with use_softmax=False (sigmoid)
- Non-zero expected loss values to verify magnitude correctness
Consider adding at least one test case with a non-zero expected loss to verify the loss magnitude calculation:
# Case 3: Verify non-trivial loss value
[
    {},
    {
        "y_pred": torch.tensor([[[[0.0, 0.0], [0.0, 0.0]]], [[[0.0, 0.0], [0.0, 0.0]]]]),
        "y_true": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]),
    },
    # expected_val: compute expected non-zero value
],
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (2)
monai/losses/unified_focal_loss.py (11 hunks)
tests/losses/test_unified_focal_loss.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
tests/losses/test_unified_focal_loss.py
monai/losses/unified_focal_loss.py
🧬 Code graph analysis (2)
tests/losses/test_unified_focal_loss.py (1)
monai/losses/unified_focal_loss.py (1)
AsymmetricUnifiedFocalLoss(179-257)
monai/losses/unified_focal_loss.py (1)
monai/utils/enums.py (1)
LossReduction(253-264)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: build-docs
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: packaging
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-os (macOS-latest)
🔇 Additional comments (5)
monai/losses/unified_focal_loss.py (3)
17-17: LGTM! The import addition for torch.nn.functional is necessary for the new log_softmax and logsigmoid operations.
143-148: LGTM! Good numerical stability practice using log_softmax/logsigmoid instead of computing log(softmax(...)).
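To make the stability point concrete, here is a minimal standalone sketch (not from the PR; the tensors and values are illustrative) showing how log(softmax(...)) underflows for large-magnitude logits while log_softmax stays finite:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[60.0, -60.0]])  # large-magnitude logits

# Naive: softmax underflows to 0 for the losing class, so log() yields -inf.
naive = torch.log(torch.softmax(logits, dim=1))

# Stable: log_softmax works in log-space and stays finite.
stable = F.log_softmax(logits, dim=1)

print(naive)   # tensor([[0., -inf]]) (approximately)
print(stable)  # tensor([[0., -120.]]) (approximately)
```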
170-175: Verify intentional asymmetry between background and foreground aggregation.
back_ce is a single channel value while fore_ce sums across all foreground channels before stacking. This means multi-class scenarios weight the foreground loss by the number of classes. Confirm this matches the paper's intended behavior.

tests/losses/test_unified_focal_loss.py (2)
55-59: LGTM! Parameterized test properly accepts constructor arguments and validates loss output.
66-75: LGTM! CUDA test properly validates loss computation on GPU with updated input values.
Signed-off-by: ytl0623 <[email protected]>
Signed-off-by: ytl0623 <[email protected]>
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
monai/losses/unified_focal_loss.py (1)
36-44: reduction parameter accepted but ignored.
AsymmetricFocalTverskyLoss.__init__ accepts reduction (line 43) and stores it (line 54), but forward() always returns torch.mean() (line 103) regardless of the reduction setting. Same issue in AsymmetricFocalLoss. Either implement reduction in these sub-losses or remove the parameter.

 def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
     # ... existing logic ...
-    loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1))
-    return loss
+    loss = torch.stack([back_dice, fore_dice], dim=-1)
+    if self.reduction == LossReduction.SUM.value:
+        return torch.sum(loss)
+    if self.reduction == LossReduction.NONE.value:
+        return loss
+    if self.reduction == LossReduction.MEAN.value:
+        return torch.mean(loss)
+    raise ValueError(f'Unsupported reduction: {self.reduction}')

Also applies to: 103-103
♻️ Duplicate comments (1)
monai/losses/unified_focal_loss.py (1)
134-134: Docstring default value incorrect.
States gamma defaults to 0.75, but the actual default is 2.

- gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
+ gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/losses/unified_focal_loss.py (11 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
Files:
monai/losses/unified_focal_loss.py
🧬 Code graph analysis (1)
monai/losses/unified_focal_loss.py (1)
monai/utils/enums.py (1)
LossReduction(253-264)
🔇 Additional comments (4)
monai/losses/unified_focal_loss.py (4)
17-17: LGTM—Import required for log probability functions.
28-66: LGTM—Activation logic and parameter addition correct. Typo fix and multi-class support properly documented. Applying activation before validation is correct since inputs are expected to be logits.
145-151: LGTM—Log probability handling is numerically stable. Using log_softmax/logsigmoid with separate log and probability tensors avoids numerical issues.
225-230: LGTM—Sub-losses properly configured with use_softmax. Parameters correctly forwarded to both sub-losses.
Signed-off-by: ytl0623 <[email protected]>
Signed-off-by: ytl0623 <[email protected]>
Signed-off-by: ytl0623 <[email protected]>
Actionable comments posted: 2
♻️ Duplicate comments (1)
monai/losses/unified_focal_loss.py (1)
180-183: Aggregation inconsistency persists.
torch.sum here vs torch.mean at line 99 in AsymmetricFocalTverskyLoss. Loss magnitude scales differently with class count.
🧹 Nitpick comments (2)
monai/losses/unified_focal_loss.py (2)
67-77: Warning check is unreachable after single-channel expansion.
After lines 67-69 expand single-channel to two-class, n_pred_ch at line 71 will always be ≥2. The warning at lines 74-75 becomes dead code for originally single-channel inputs. Consider moving the warning check before the expansion, or remove it since the expansion handles single-channel inputs.

+    n_pred_ch = y_pred.shape[1]
+    if self.to_onehot_y and n_pred_ch == 1:
+        warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+
     if y_pred.shape[1] == 1:
         y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
         y_true = torch.cat([1 - y_true, y_true], dim=1)
-    n_pred_ch = y_pred.shape[1]
-
-    if self.to_onehot_y:
-        if n_pred_ch == 1:
-            warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
-        else:
-            y_true = one_hot(y_true, num_classes=n_pred_ch)
+    if self.to_onehot_y and y_pred.shape[1] > 1:
+        y_true = one_hot(y_true, num_classes=y_pred.shape[1])
254-260: Reduction logic has no effect.
Sub-losses at lines 249-250 return scalars (they use their own reduction). The outer reduction here (lines 254-260) on a scalar is effectively a no-op. Consider passing reduction to the sub-losses or documenting this behavior.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/losses/unified_focal_loss.py (10 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
Files:
monai/losses/unified_focal_loss.py
🔇 Additional comments (1)
monai/losses/unified_focal_loss.py (1)
233-238: LGTM! Sub-losses correctly wired with shared parameters.
Signed-off-by: ytl0623 <[email protected]>
Signed-off-by: ytl0623 <[email protected]>
Actionable comments posted: 0
♻️ Duplicate comments (3)
monai/losses/unified_focal_loss.py (3)
36-80: Fix single-channel + to_onehot_y ordering and clamp foreground dice base.
Two points here:

1. Bug: to_onehot_y after single-channel expansion.
In forward, you expand 1-channel predictions/targets before computing n_pred_ch and applying to_onehot_y:
- For binary cases with to_onehot_y=True, this now calls one_hot() on a 2-channel float tensor (y_true = cat([1 - y_true, y_true], ...)) instead of the original label map, and it also defeats the previous "single channel prediction, to_onehot_y=True ignored." behaviour.

Reorder so n_pred_ch and to_onehot_y are evaluated on the original channel count, then expand to two channels. Current:

    if self.use_softmax:
        y_pred = torch.softmax(y_pred, dim=1)
    else:
        y_pred = torch.sigmoid(y_pred)

    if y_pred.shape[1] == 1:
        y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
        y_true = torch.cat([1 - y_true, y_true], dim=1)

    n_pred_ch = y_pred.shape[1]

    if self.to_onehot_y:
        if n_pred_ch == 1:
            warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
        else:
            y_true = one_hot(y_true, num_classes=n_pred_ch)

Proposed:

    if self.use_softmax:
        y_pred = torch.softmax(y_pred, dim=1)
    else:
        y_pred = torch.sigmoid(y_pred)

    n_pred_ch = y_pred.shape[1]

    if self.to_onehot_y:
        if n_pred_ch == 1:
            warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
        else:
            y_true = one_hot(y_true, num_classes=n_pred_ch)

    if y_pred.shape[1] == 1:
        y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
        y_true = torch.cat([1 - y_true, y_true], dim=1)
2. Numeric stability: clamp base before power.
fore_dice = torch.pow(1 - dice_class[:, 1:], 1 - self.gamma) can still explode if a user passes gamma > 1 and dice_class[:, 1:] → 1 (negative exponent on values near zero). To guard this, clamp the base:

-    fore_dice = torch.pow(1 - dice_class[:, 1:], 1 - self.gamma)
+    fore_dice_base = torch.clamp(1 - dice_class[:, 1:], min=self.epsilon)
+    fore_dice = torch.pow(fore_dice_base, 1 - self.gamma)

Also applies to: 84-99
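A small self-contained sketch of the failure mode (values are illustrative, not from the PR's tests) shows how the clamp bounds the result instead of letting it blow up:

```python
import torch

eps = 1e-7
gamma = 2.0  # exponent 1 - gamma = -1
dice = torch.tensor([0.5, 0.999999, 1.0])  # increasingly "perfect" foreground dice

unclamped = torch.pow(1 - dice, 1 - gamma)
clamped = torch.pow(torch.clamp(1 - dice, min=eps), 1 - gamma)

print(unclamped)  # tensor([2.0000e+00, 1.0000e+06, inf]) (approximately)
print(clamped)    # tensor([2.0000e+00, 1.0000e+06, 1.0000e+07]) - capped at 1/eps
```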
118-185: Reorder to_onehot_y, fix single-channel log-prob stability, and consider foreground aggregation.
This block has three intertwined issues:

1. Bug: to_onehot_y after single-channel expansion (same as Tversky).
Expanding y_pred/y_true to two channels before computing n_pred_ch and applying to_onehot_y means that for binary cases with to_onehot_y=True you now call one_hot() on an already expanded 2-channel float tensor instead of the original label map, and you lose the prior "ignore + warn" semantics. Reorder to evaluate to_onehot_y on the original channel count, then expand. Current:

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        if self.use_softmax:
            y_log_pred = F.log_softmax(y_pred, dim=1)
            y_pred = torch.exp(y_log_pred)
        else:
            y_log_pred = F.logsigmoid(y_pred)
            y_pred = torch.sigmoid(y_pred)

        if y_pred.shape[1] == 1:
            y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
            y_log_pred = torch.log(torch.clamp(y_pred, 1e-7, 1.0))
            y_true = torch.cat([1 - y_true, y_true], dim=1)

        n_pred_ch = y_pred.shape[1]
        if self.to_onehot_y:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
            else:
                y_true = one_hot(y_true, num_classes=n_pred_ch)

Proposed:

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        logits = y_pred
        if self.use_softmax:
            y_log_pred = F.log_softmax(logits, dim=1)
            y_pred = torch.exp(y_log_pred)
        else:
            y_log_pred = F.logsigmoid(logits)
            y_pred = torch.sigmoid(logits)

        n_pred_ch = y_pred.shape[1]
        if self.to_onehot_y:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
            else:
                y_true = one_hot(y_true, num_classes=n_pred_ch)

        if y_pred.shape[1] == 1:
            if self.use_softmax:
                # Softmax over a single channel is degenerate; fall back to prob-based log.
                y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
                y_log_pred = torch.log(torch.clamp(y_pred, self.epsilon, 1.0))
            else:
                pos_prob = y_pred
                neg_prob = 1 - pos_prob
                pos_log_prob = y_log_pred
                neg_log_prob = F.logsigmoid(-logits)
                y_pred = torch.cat([neg_prob, pos_prob], dim=1)
                y_log_pred = torch.cat([neg_log_prob, pos_log_prob], dim=1)
            y_true = torch.cat([1 - y_true, y_true], dim=1)
2. Numeric stability: don't re-log probabilities.
The change above also addresses the prior issue where you recomputed y_log_pred = log(clamp(y_pred)) in the 1-channel path, discarding the stable log_softmax/logsigmoid computed from logits. For the binary sigmoid case we now derive log p and log(1 - p) directly from logits using F.logsigmoid, which is the numerically stable form.

3. Foreground CE aggregation vs Tversky (optional but recommended).
Here foreground CE aggregates with sum over foreground classes:

    fore_ce = cross_entropy[:, 1:]
    fore_ce = self.delta * fore_ce
    if fore_ce.shape[1] > 1:
        fore_ce = torch.sum(fore_ce, dim=1)

while AsymmetricFocalTverskyLoss uses mean across foreground classes. This keeps the unified loss scale tied to the number of foreground classes. If you want scale invariance w.r.t. class count and consistency with Tversky, switching to mean is cleaner:

-    if fore_ce.shape[1] > 1:
-        fore_ce = torch.sum(fore_ce, dim=1)
+    if fore_ce.shape[1] > 1:
+        fore_ce = torch.mean(fore_ce, dim=1)

If the `sum` behaviour is intentional, consider documenting the choice and adding a unit test that asserts the scaling (a tiny illustration of the scaling follows).
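For intuition on the scaling difference, here is a small standalone illustration (shapes and values are made up): with a constant per-class foreground loss, sum grows with the number of foreground classes while mean does not.

```python
import torch

fore = torch.full((4, 3), 0.2)  # batch of 4, 3 foreground classes, per-class loss 0.2

print(fore.sum(dim=1))   # tensor([0.6000, 0.6000, 0.6000, 0.6000]) - scales with class count
print(fore.mean(dim=1))  # tensor([0.2000, 0.2000, 0.2000, 0.2000]) - invariant to class count
```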
211-216: Clarify y_true shape for to_onehot_y=True in docstring.
The forward docstring still says y_true is always BNH[WD], but with to_onehot_y=True you expect label maps that are converted to one-hot internally. This matches an earlier comment and is still unresolved. Consider updating as:

     Args:
         y_pred : the shape should be BNH[WD], where N is the number of classes.
             The input should be the original logits since it will be transformed by
-            a sigmoid/softmax in the forward function.
-        y_true : the shape should be BNH[WD], where N is the number of classes.
+            a sigmoid/softmax in the forward function.
+        y_true : the shape should be BNH[WD] (one-hot format), or B1H[WD] with class
+            indices when ``to_onehot_y=True``.

Also applies to: 241-246
🧹 Nitpick comments (1)
monai/losses/unified_focal_loss.py (1)
47-53: Tighten epsilon docstrings to state purpose.
The epsilon descriptions are grammatically off and don't say what epsilon is for. A more direct wording would help:

- epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7.
+ epsilon : small constant for numerical stability (e.g., to avoid division by
+     zero or log(0)). Defaults to 1e-7.

Apply similarly to both AsymmetricFocalTverskyLoss and AsymmetricFocalLoss.

Also applies to: 129-136
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/losses/unified_focal_loss.py (9 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
Files:
monai/losses/unified_focal_loss.py
🧬 Code graph analysis (1)
monai/losses/unified_focal_loss.py (1)
monai/utils/enums.py (1)
LossReduction(253-264)
🔇 Additional comments (1)
monai/losses/unified_focal_loss.py (1)
200-237: Constructor wiring for use_softmax looks correct; ensure tests cover both modes.
The new use_softmax flag is correctly threaded from AsymmetricUnifiedFocalLoss into AsymmetricFocalLoss and AsymmetricFocalTverskyLoss along with to_onehot_y, gamma, and delta. That matches the intended API.

Please make sure tests exercise at least the following (a possible parameterization is sketched after this list):
- binary and multi-class cases with use_softmax=True (logits + one-hot/label inputs),
- binary sigmoid cases with use_softmax=False (including single-channel outputs),
- to_onehot_y both True and False in combination with the above.
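A hypothetical pytest-style sketch of such a parameterization (not part of the PR; MONAI's own tests use unittest/parameterized, and the shapes here are purely illustrative):

```python
import itertools

import pytest
import torch
import torch.nn.functional as F

from monai.losses import AsymmetricUnifiedFocalLoss

CASES = list(itertools.product([False, True], [False, True]))  # (use_softmax, to_onehot_y)


@pytest.mark.parametrize("use_softmax,to_onehot_y", CASES)
def test_multiclass_logits_input(use_softmax, to_onehot_y):
    n_classes = 3
    logits = torch.randn(2, n_classes, 8, 8)  # raw logits, as the new API expects
    labels = torch.randint(0, n_classes, (2, 8, 8))
    if to_onehot_y:
        target = labels.unsqueeze(1).float()  # B1HW label map, converted internally
    else:
        target = F.one_hot(labels, n_classes).permute(0, 3, 1, 2).float()  # BNHW one-hot
    loss = AsymmetricUnifiedFocalLoss(use_softmax=use_softmax, to_onehot_y=to_onehot_y)
    value = loss(logits, target)
    assert torch.isfinite(value)
```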
Signed-off-by: ytl0623 <[email protected]>
Signed-off-by: ytl0623 <[email protected]>
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
monai/losses/unified_focal_loss.py (1)
61-80: Fix two critical issues with single-channel handling and to_onehot_y semantics in both AsymmetricFocalTverskyLoss and AsymmetricFocalLoss.
Verification confirms both issues:

Issue 1 verified: When y_pred.shape[1]==1 with to_onehot_y=True, one_hot() is called on an already 2-channel expanded y_true (containing [1-y_true, y_true]). This violates the one_hot contract; it expects scalar class indices, not probabilities. Test cases do not cover this scenario.

Issue 2 verified: When use_softmax=True with single-channel input, softmax(x, dim=1) returns all 1.0, then concatenation produces exactly [0, 1] everywhere, completely discarding the input logits. Test cases do not cover this either (a minimal reproduction follows the list below).
Apply the suggested refactoring pattern to both classes:
- Perform to_onehot_y conversion before expanding channels
- Explicitly disallow use_softmax=True with single-channel predictions (raise ValueError or strong warning)
- Add tests for both edge cases post-fix
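Issue 2 is easy to reproduce in isolation; a minimal sketch (tensor shapes are illustrative) shows why single-channel softmax discards the logits:

```python
import torch

logits = torch.randn(2, 1, 4, 4)      # single-channel prediction
probs = torch.softmax(logits, dim=1)  # softmax over a size-1 dim is exp(x)/exp(x) == 1
print(torch.all(probs == 1.0))        # tensor(True)

# Expanding afterwards yields a constant [0, 1] map, independent of the logits.
expanded = torch.cat([1 - probs, probs], dim=1)
print(expanded.unique())              # tensor([0., 1.])
```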
♻️ Duplicate comments (1)
monai/losses/unified_focal_loss.py (1)
118-135: Single-channel AsymmetricFocalLoss recomputes log-probs from clamped probs instead of logits.
This is the same numerical-stability concern previously raised: after computing y_log_pred via log_softmax/logsigmoid, the single-channel path overwrites it with:

    if y_pred.shape[1] == 1:
        y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
        y_log_pred = torch.log(torch.clamp(y_pred, 1e-7, 1.0))

For large-magnitude logits, log(1 - sigmoid(z)) via 1 - p suffers cancellation and is capped by the clamp, whereas the stable form from logits is:

    log p1 = F.logsigmoid(z)
    log p0 = F.logsigmoid(-z)

Suggested refactor (keep logits and derive both channels from them):

 def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
-    if self.use_softmax:
-        y_log_pred = F.log_softmax(y_pred, dim=1)
-        y_pred = torch.exp(y_log_pred)
-    else:
-        y_log_pred = F.logsigmoid(y_pred)
-        y_pred = torch.sigmoid(y_pred)
-
-    if y_pred.shape[1] == 1:
-        y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
-        y_log_pred = torch.log(torch.clamp(y_pred, 1e-7, 1.0))
+    logits = y_pred
+    if self.use_softmax:
+        y_log_pred = F.log_softmax(logits, dim=1)
+        y_pred = torch.exp(y_log_pred)
+    else:
+        y_pred = torch.sigmoid(logits)
+        if logits.shape[1] == 1:
+            # Binary: compute stable log-probs for both channels from logits.
+            log_p1 = F.logsigmoid(logits)
+            log_p0 = F.logsigmoid(-logits)
+            y_log_pred = torch.cat([log_p0, log_p1], dim=1)
+            y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
+        else:
+            y_log_pred = F.logsigmoid(logits)
 @@
-    if y_pred.shape[1] == 1:
-        ...
-        y_true = torch.cat([1 - y_true, y_true], dim=1)
+    if y_pred.shape[1] == 1 and not self.to_onehot_y:
+        y_true = torch.cat([1 - y_true, y_true], dim=1)

This keeps the stable log-probability computation and avoids recomputing logs from clamped probabilities.

How does PyTorch recommend computing numerically stable log-probabilities for binary classification from logits (e.g., using `F.logsigmoid`)?

Also applies to: 143-155
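A quick standalone illustration of the cancellation (values are illustrative): for a large positive logit, 1 - sigmoid(z) underflows to 0 in float32, while F.logsigmoid(-z) keeps the exact log-probability.

```python
import torch
import torch.nn.functional as F

z = torch.tensor([40.0])  # large positive logit

naive = torch.log(1 - torch.sigmoid(z))  # 1 - sigmoid(40) underflows to 0 -> log(0) = -inf
stable = F.logsigmoid(-z)                # stable background log-prob from the logit

print(naive)   # tensor([-inf])
print(stable)  # tensor([-40.]) (approximately)
```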
🧹 Nitpick comments (1)
monai/losses/unified_focal_loss.py (1)
36-53: AsymmetricFocalTverskyLoss.__init__ use_softmax wiring and docs mostly OK; epsilon doc could be clearer.
- Signature and stored self.use_softmax match the new behavior; docstrings align with the default values.
- The epsilon line is grammatically odd and vague ("similarly smooth value"). Consider rephrasing to something like:

- epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7.
+ epsilon : small smoothing constant to avoid division by zero. Defaults to 1e-7.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/losses/unified_focal_loss.py (10 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
Files:
monai/losses/unified_focal_loss.py
🧬 Code graph analysis (1)
monai/losses/unified_focal_loss.py (1)
monai/utils/enums.py (1)
LossReduction(253-264)
🔇 Additional comments (1)
monai/losses/unified_focal_loss.py (1)
187-236: Ensure test coverage spans binary vs. multi-class scenarios with both to_onehot_y modes; verify handling of edge cases in composed sub-losses.
to_onehot_ymodes; verify handling of edge cases in composed sub‑losses.Tests in
tests/losses/test_unified_focal_loss.pyinclude binary segmentation cases (Case 0) and confirm constructor threading is sound. However, verify that:
- Test coverage explicitly covers multi-class (C>1) scenarios alongside binary, with both to_onehot_y=False and to_onehot_y=True.
- use_softmax=True behaves correctly for multi-class (expected) and gracefully for binary if a guard exists in sub-losses.
- Both AsymmetricFocalLoss and AsymmetricFocalTverskyLoss handle the full parameter matrix correctly when composed.
Signed-off-by: ytl0623 <[email protected]>
Signed-off-by: ytl0623 <[email protected]>
Actionable comments posted: 1
🧹 Nitpick comments (3)
monai/losses/unified_focal_loss.py (3)
61-81: Softmax/sigmoid gating and binary expansion look sound; consider warning stacklevel.
The guard against use_softmax with a single channel, the to_onehot_y path, and the binary expansion to two channels are all consistent with the new multi-class/binary semantics. For better caller-facing diagnostics, consider adding stacklevel=2 to the warnings.warn call so the warning points at the user code rather than inside the loss implementation.
147-196: Logits-based CE and single-channel expansion are numerically stable and shape-consistent.
Storing logits, using F.log_softmax/F.logsigmoid, and deriving log(1-p) via F.logsigmoid(-y_logits) fixes the earlier numerical-stability problem when expanding a single-channel sigmoid output to two channels. The foreground aggregation via mean across classes 1+ is also in line with the Tversky variant and the multi-class design.

One minor clean-up you could consider is only computing y_logits when n_pred_ch == 1, since it's unused otherwise.
248-267: Unified loss reduction currently ignores the reduction argument for sub-losses.
AsymmetricUnifiedFocalLoss always constructs AsymmetricFocalLoss and AsymmetricFocalTverskyLoss with their default reduction=LossReduction.MEAN, then applies another reduction on the combined loss. This means reduction="none" or "sum" on AsymmetricUnifiedFocalLoss can't actually produce un-reduced per-sample outputs; you only ever see a scalar.

If you want reduction on this wrapper to behave like other MONAI losses, consider either of the following (a sketch of the second option follows this list):
- Passing reduction=self.reduction into both sub-loss constructors and treating their outputs as final (dropping the extra reduction here), or
- Setting the sub-losses to reduction=LossReduction.NONE and keeping the current final reduction logic solely in this class.

Worth double-checking against existing callers and tests to ensure the intended behavior.
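A minimal sketch of the second option, assuming the sub-losses honour a reduction argument (as proposed earlier in this review) and return per-sample tensors of the same shape; the class name and combination formula here are illustrative, not the actual MONAI implementation:

```python
import torch

from monai.losses import AsymmetricFocalLoss, AsymmetricFocalTverskyLoss
from monai.utils import LossReduction


class UnifiedLossSketch(torch.nn.Module):
    def __init__(self, weight: float = 0.5, reduction=LossReduction.MEAN):
        super().__init__()
        self.weight = weight
        self.reduction = LossReduction(reduction).value
        # Sub-losses return un-reduced tensors; the wrapper owns the final reduction.
        self.focal = AsymmetricFocalLoss(reduction=LossReduction.NONE)
        self.tversky = AsymmetricFocalTverskyLoss(reduction=LossReduction.NONE)

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        # Weighted combination of the two un-reduced terms (assumed to broadcast).
        loss = self.weight * self.focal(y_pred, y_true) + (1 - self.weight) * self.tversky(y_pred, y_true)
        if self.reduction == LossReduction.SUM.value:
            return torch.sum(loss)
        if self.reduction == LossReduction.MEAN.value:
            return torch.mean(loss)
        return loss  # "none"
```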
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/losses/unified_focal_loss.py (9 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
Files:
monai/losses/unified_focal_loss.py
🧬 Code graph analysis (1)
monai/losses/unified_focal_loss.py (2)
monai/utils/enums.py (1)
LossReduction(253-264)
monai/networks/utils.py (1)
one_hot(170-220)
🪛 Ruff (0.14.5)
monai/losses/unified_focal_loss.py
65-65: Avoid specifying long messages outside the exception class
(TRY003)
69-69: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
83-83: Avoid specifying long messages outside the exception class
(TRY003)
151-151: Avoid specifying long messages outside the exception class
(TRY003)
155-155: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
177-177: Avoid specifying long messages outside the exception class
(TRY003)
237-237: Undefined name weight
(F821)
🔇 Additional comments (1)
monai/losses/unified_focal_loss.py (1)
88-107: Multi-class foreground/background aggregation is consistent and symmetric.
Using class 0 as background and aggregating classes 1+ via mean keeps the scaling of the foreground term independent of the number of foreground classes and matches the updated asymmetric focal behavior. No issues from a correctness standpoint.
def __init__(
    self,
    to_onehot_y: bool = False,
    num_classes: int = 2,
    weight: float = 0.5,
    gamma: float = 0.5,
    use_softmax: bool = False,
    delta: float = 0.7,
    gamma: float = 2,
    reduction: LossReduction | str = LossReduction.MEAN,
):
    """
    Args:
        to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
        num_classes : number of classes, it only supports 2 now. Defaults to 2.
        to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
        use_softmax: whether to use softmax to transform the original logits into probabilities.
            If True, softmax is used. If False, sigmoid is used. Defaults to False.
        delta : weight of the background. Defaults to 0.7.
        gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
        epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
        weight : weight for each loss function, if it's none it's 0.5. Defaults to None.
    Example:
        >>> import torch
        >>> from monai.losses import AsymmetricUnifiedFocalLoss
        >>> pred = torch.ones((1,1,32,32), dtype=torch.float32)
        >>> grnd = torch.ones((1,1,32,32), dtype=torch.int64)
        >>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True)
        >>> pred = torch.randn((1, 3, 32, 32))
        >>> grnd = torch.randint(0, 3, (1, 1, 32, 32))
        >>> fl = AsymmetricUnifiedFocalLoss(use_softmax=True, to_onehot_y=True)
        >>> fl(pred, grnd)
    """
    super().__init__(reduction=LossReduction(reduction).value)
    self.to_onehot_y = to_onehot_y
    self.num_classes = num_classes
    self.weight: float = weight
    self.gamma = gamma
    self.delta = delta
    self.weight: float = weight
    self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta)
    self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta)
    self.use_softmax = use_softmax
    self.asy_focal_loss = AsymmetricFocalLoss(
        to_onehot_y=to_onehot_y, gamma=self.gamma, delta=self.delta, use_softmax=use_softmax
    )
    self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
        to_onehot_y=to_onehot_y, gamma=self.gamma, delta=self.delta, use_softmax=use_softmax
    )
weight is undefined in __init__ and docstring gamma default is stale.
AsymmetricUnifiedFocalLoss.__init__ doesn’t accept a weight argument but assigns self.weight: float = weight, which will raise at runtime. The docstring also still says gamma defaults to 0.75 while the signature uses 2, and it doesn’t document weight at all.
Recommend restoring weight as an explicit argument with a sane default and aligning the docstring with the implementation. For example:
- def __init__(
- self,
- to_onehot_y: bool = False,
- use_softmax: bool = False,
- delta: float = 0.7,
- gamma: float = 2,
- reduction: LossReduction | str = LossReduction.MEAN,
- ):
+ def __init__(
+ self,
+ to_onehot_y: bool = False,
+ use_softmax: bool = False,
+ delta: float = 0.7,
+ gamma: float = 2,
+ weight: float = 0.5,
+ reduction: LossReduction | str = LossReduction.MEAN,
+ ):
@@
- Args:
- to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
- use_softmax: whether to use softmax to transform the original logits into probabilities.
- If True, softmax is used. If False, sigmoid is used. Defaults to False.
- delta : weight of the background. Defaults to 0.7.
- gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
+ Args:
+ to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
+ use_softmax: whether to use softmax to transform the original logits into probabilities.
+ If True, softmax is used. If False, sigmoid is used. Defaults to False.
+ delta : weight of the background. Defaults to 0.7.
+ gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
+ weight: weight for combining the focal and focal-Tversky terms. Defaults to 0.5.
@@
- self.to_onehot_y = to_onehot_y
- self.weight: float = weight
+ self.to_onehot_y = to_onehot_y
+ self.weight: float = weight
@@
- self.asy_focal_loss = AsymmetricFocalLoss(
- to_onehot_y=to_onehot_y, gamma=self.gamma, delta=self.delta, use_softmax=use_softmax
- )
- self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
- to_onehot_y=to_onehot_y, gamma=self.gamma, delta=self.delta, use_softmax=use_softmax
- )
+ self.asy_focal_loss = AsymmetricFocalLoss(
+ to_onehot_y=to_onehot_y,
+ use_softmax=use_softmax,
+ delta=self.delta,
+ gamma=self.gamma,
+ )
+ self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
+ to_onehot_y=to_onehot_y,
+ use_softmax=use_softmax,
+ delta=self.delta,
+ gamma=self.gamma,
+ )

(Last part just reorders keywords for clarity; behavior is unchanged.) This keeps the public API coherent and avoids a runtime NameError. As per coding guidelines, keeping docstrings in sync with signatures is important.
🧰 Tools
🪛 Ruff (0.14.5)
237-237: Undefined name weight
(F821)
🤖 Prompt for AI Agents
In monai/losses/unified_focal_loss.py around lines 211 to 246, __init__
references an undefined variable weight and the docstring lists the wrong
default for gamma; restore weight as an explicit parameter (e.g., weight: float
= 1.0) in the signature and assign self.weight = weight with an appropriate type
annotation, update the docstring to document the new weight parameter and
correct the gamma default to match the signature (gamma=2), and ensure any
downstream uses of self.weight or parameter ordering are consistent with the
public API.
Actionable comments posted: 2
♻️ Duplicate comments (1)
monai/losses/unified_focal_loss.py (1)
226-226: Docstring gamma default is incorrect.
Line 226 states "Defaults to 0.75" but the signature at line 216 uses gamma: float = 2.

- gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
+ gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
🧹 Nitpick comments (1)
monai/losses/unified_focal_loss.py (1)
69-69: Add stacklevel to warning for better diagnostics.
Without stacklevel=2, the warning points to this line rather than the caller.

- warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+ warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)

Same applies to line 155.
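For context, a tiny runnable sketch (the call chain here is hypothetical) showing what stacklevel changes about the reported warning location:

```python
import warnings

def forward():
    # Default stacklevel=1: the warning is attributed to this line inside the loss.
    warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
    # stacklevel=2: the warning is attributed to the caller of forward(), i.e. user code.
    warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)

def user_code():
    forward()

user_code()
```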
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/losses/unified_focal_loss.py (9 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
Files:
monai/losses/unified_focal_loss.py
🧬 Code graph analysis (1)
monai/losses/unified_focal_loss.py (2)
monai/utils/enums.py (1)
LossReduction(253-264)
monai/networks/utils.py (1)
one_hot(170-220)
🪛 Ruff (0.14.5)
monai/losses/unified_focal_loss.py
65-65: Avoid specifying long messages outside the exception class
(TRY003)
69-69: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
83-83: Avoid specifying long messages outside the exception class
(TRY003)
151-151: Avoid specifying long messages outside the exception class
(TRY003)
155-155: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
177-177: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (5)
monai/losses/unified_focal_loss.py (5)
36-59: LGTM! Constructor and docstrings are well-aligned. Parameters documented correctly.
122-145: LGTM! Constructor parameters and docstrings are consistent.
159-174: Numerical stability properly addressed. The use of F.logsigmoid(-y_logits) for the background log-probability (line 172) correctly maintains numerical stability. Good fix.
182-195: LGTM! Asymmetric focal weighting correctly applied: background gets (1-p)^gamma modulation, foreground preserved. Aggregation consistent with Tversky loss.
250-270: LGTM! Forward delegation and reduction logic are correct. Docstring properly documents both shape options.
# Class 1+ is Foreground
fore_dice = torch.pow(1 - dice_class[:, 1:], 1 - self.gamma)

if fore_dice.shape[1] > 1:
    fore_dice = torch.mean(fore_dice, dim=1)
else:
    fore_dice = fore_dice.squeeze(1)
Potential numerical issue when gamma > 1.
When gamma > 1, the exponent 1 - gamma becomes negative. If dice_class[:, 1:] approaches 1.0 (perfect prediction), (1 - dice_class) approaches 0, and 0^negative explodes.
Default gamma=0.75 is safe, but document or validate that gamma <= 1 is expected for this class.
+ if self.gamma > 1:
+ warnings.warn("gamma > 1 may cause numerical instability in AsymmetricFocalTverskyLoss")
+
# Class 1+ is Foreground
fore_dice = torch.pow(1 - dice_class[:, 1:], 1 - self.gamma)

Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In monai/losses/unified_focal_loss.py around lines 97 to 103, the computation
fore_dice = torch.pow(1 - dice_class[:, 1:], 1 - self.gamma) can explode when
gamma > 1 because the exponent becomes negative and values of (1 - dice) near
zero lead to huge results; to fix this, validate or document gamma constraints
and add a guard: either assert or raise a clear ValueError if self.gamma > 1, or
clamp the base with a small epsilon (e.g., base = torch.clamp(1 - dice_class[:,
1:], min=eps)) before the pow, ensuring stable numerical behavior for negative
exponents; include a unit-test or a docstring note that gamma is expected <= 1
if you choose the validation route.
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
    to_onehot_y=to_onehot_y, gamma=self.gamma, delta=self.delta, use_softmax=use_softmax
)
Numerical instability: gamma=2 causes explosion in Tversky component.
AsymmetricFocalTverskyLoss computes (1 - dice)^(1 - gamma). With gamma=2 (default here), exponent = -1. When dice approaches 1.0 (near-perfect prediction), this explodes to infinity.
Consider using separate gamma values for each sub-loss or clamping the base in Tversky:
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
- to_onehot_y=to_onehot_y, gamma=self.gamma, delta=self.delta, use_softmax=use_softmax
+ to_onehot_y=to_onehot_y, gamma=0.75, delta=self.delta, use_softmax=use_softmax
)Or add clamping in AsymmetricFocalTverskyLoss.forward:
fore_dice_base = torch.clamp(1 - dice_class[:, 1:], min=self.epsilon)
fore_dice = torch.pow(fore_dice_base, 1 - self.gamma)🤖 Prompt for AI Agents
In monai/losses/unified_focal_loss.py around lines 246-248, the
AsymmetricFocalTverskyLoss is instantiated with gamma=self.gamma which can make
the Tversky term explode when gamma>1; modify AsymmetricFocalTverskyLoss instead
to clamp the base before exponentiation (e.g., compute fore_dice_base =
torch.clamp(1 - dice_class[:, 1:], min=self.epsilon)) and then raise to the (1 -
gamma) power, and/or add an optional separate parameter (e.g., tversky_gamma or
epsilon) to the AsymmetricFocalTverskyLoss constructor so callers can pass
different gamma or epsilon values without changing unified_focal_loss.py. Ensure
the forward uses the clamp value and existing epsilon (or new param) before
calling torch.pow to prevent dividing/exponentiating by values near zero.
Signed-off-by: ytl0623 <[email protected]>
Actionable comments posted: 0
♻️ Duplicate comments (1)
monai/losses/unified_focal_loss.py (1)
97-103: Numerical instability when gamma > 1.
torch.pow(1 - dice_class[:, 1:], 1 - self.gamma) uses exponent 1 - gamma. The default gamma=0.75 is safe, but if a user passes gamma > 1, the exponent becomes negative. When dice approaches 1.0 (near-perfect prediction), (1 - dice) approaches 0, and 0^negative explodes.

Add validation in __init__:

+ if gamma > 1:
+     raise ValueError(f"gamma must be <= 1 for AsymmetricFocalTverskyLoss, got {gamma}")
  self.gamma = gamma

Or clamp the base before applying the power:

- fore_dice = torch.pow(1 - dice_class[:, 1:], 1 - self.gamma)
+ fore_dice_base = torch.clamp(1 - dice_class[:, 1:], min=self.epsilon)
+ fore_dice = torch.pow(fore_dice_base, 1 - self.gamma)
🧹 Nitpick comments (2)
monai/losses/unified_focal_loss.py (2)
69-69: Missing stacklevel in warning.

- warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+ warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)

155-155: Missing stacklevel in warning.

- warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+ warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/losses/unified_focal_loss.py (9 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
Files:
monai/losses/unified_focal_loss.py
🧬 Code graph analysis (1)
monai/losses/unified_focal_loss.py (2)
monai/utils/enums.py (1)
LossReduction(253-264)
monai/networks/utils.py (1)
one_hot(170-220)
🪛 Ruff (0.14.5)
monai/losses/unified_focal_loss.py
65-65: Avoid specifying long messages outside the exception class
(TRY003)
69-69: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
83-83: Avoid specifying long messages outside the exception class
(TRY003)
151-151: Avoid specifying long messages outside the exception class
(TRY003)
155-155: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
177-177: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (3)
monai/losses/unified_focal_loss.py (3)
159-174: Numerically stable single-channel expansion.
Preserving logits and using F.logsigmoid(-y_logits) for background log-probabilities is the correct approach. This avoids log(1 - sigmoid(x)) instability.

243-254: Sub-loss wiring looks correct.
Both sub-losses receive the use_softmax parameter, ensuring consistent probability transformation across the unified loss.
270-276: Based on my investigation, I was unable to directly access the MONAI repository code due to cloning failures. However, my web search findings reveal important context about how MONAI losses work:
- MONAI losses support reduction parameters with options: "none", "mean", "sum"
- Sub-losses (TverskyLoss, FocalLoss) can compute per-sample losses when initialized with reduction="none"
- The original Unified Focal Loss implementation computes per-sample losses that require reduction
The critical issue is that I cannot verify the actual code at lines 265-276 to determine:
- Whether sub-losses are instantiated with reduction="none" (returning per-sample tensors) or default reduction (returning scalars)
- Whether line 268's loss variable is actually a scalar or a per-sample tensor

The reviewer's concern is structurally valid (applying torch.sum() and torch.mean() to scalars is indeed a no-op), but the premise cannot be confirmed without code inspection.
Reduction logic requires verification against actual sub-loss instantiation. The concern raised is structurally sound: if sub-losses use reduction="mean" by default, applying torch.sum() or torch.mean() at lines 271–275 would be redundant no-ops on scalars. However, if sub-losses return per-sample tensors (via reduction="none"), this reduction logic is necessary. Verify: (1) how AsymmetricFocalLoss and TverskyLoss are initialized in lines ~260–268, (2) whether their reduction parameter is explicitly set or uses defaults, and (3) the actual shape of loss at line 268.
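The "reducing a scalar is a no-op" part is easy to check in isolation (values are illustrative):

```python
import torch

scalar_loss = torch.tensor(0.1234)  # what a mean-reduced sub-loss would return

print(torch.mean(scalar_loss))  # tensor(0.1234) - mean of a 0-dim tensor is itself
print(torch.sum(scalar_loss))   # tensor(0.1234) - likewise for sum
```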
…nifiedFocalLoss
Fixes #8603 .
Description
- Added a use_softmax argument: allows the loss function to accept raw logits and handle activation (sigmoid or softmax) internally for better numerical stability.
- Removed the restriction on the number of classes (num_classes) and simplified the forward logic to use vectorized operations.
- Added tests covering both use_softmax modes and multi-class scenarios using logits input.
./runtests.sh -f -u --net --coverage../runtests.sh --quick --unittests --disttests.make htmlcommand in thedocs/folder.