
Conversation


@ytl0623 ytl0623 commented Nov 25, 2025

…nifiedFocalLoss

Fixes #8603.

Description

  • Added use_softmax argument: allows the loss function to accept raw logits and handle the activation (sigmoid or softmax) internally for better numerical stability (see the usage sketch after this list).
  • Multi-class Support: Removed hardcoded binary segmentation restrictions. The loss now supports multi-channel inputs (Channel 0 as Background, Channels 1+ as Foreground).
  • Refactoring: Removed redundant arguments (e.g., num_classes) and simplified the forward logic to use vectorized operations.
  • Tests: Cover use_softmax modes and multi-class scenarios using logits input.
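As a quick illustration of the intended interface, a usage sketch based on the description above and the docstring example later in this thread; exact defaults and argument handling may differ in the final code:

import torch
from monai.losses import AsymmetricUnifiedFocalLoss

# Multi-class: raw 3-channel logits, integer label map converted to one-hot internally.
pred = torch.randn(1, 3, 32, 32)             # logits, channel 0 = background
grnd = torch.randint(0, 3, (1, 1, 32, 32))   # class indices
loss_fn = AsymmetricUnifiedFocalLoss(use_softmax=True, to_onehot_y=True)
print(loss_fn(pred, grnd))

# Binary: single-channel logits handled with sigmoid internally.
pred_bin = torch.randn(2, 1, 32, 32)
grnd_bin = torch.randint(0, 2, (2, 1, 32, 32)).float()
loss_fn_bin = AsymmetricUnifiedFocalLoss(use_softmax=False)
print(loss_fn_bin(pred_bin, grnd_bin))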

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.


coderabbitai bot commented Nov 25, 2025

Walkthrough

Adds a use_softmax flag to AsymmetricFocalTverskyLoss, AsymmetricFocalLoss, and AsymmetricUnifiedFocalLoss to switch between softmax (multi-class) and sigmoid (binary) probability transforms. Forward paths now apply softmax/log_softmax or sigmoid/log-sigmoid depending on use_softmax, expand single-channel predictions when appropriate, compute losses per class with separate background/foreground handling, and raise if use_softmax is requested for single-channel input. AsymmetricUnifiedFocalLoss drops num_classes, wires use_softmax and to_onehot_y into sub-losses, and tests updated for binary and multi-class cases.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

  • Mixed API and logic changes across multiple loss classes and tests.
  • Review focus:
    • Correctness of channel/class dimension handling and single-channel expansion.
    • Proper propagation of use_softmax and to_onehot_y into sub-losses.
    • Use of softmax/log_softmax vs sigmoid/log-sigmoid and numerical stability (epsilons); see the short sketch after this list.
    • Removal of num_classes from AsymmetricUnifiedFocalLoss and test updates in tests/losses/test_unified_focal_loss.py.
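For reference on the stability point above, a minimal PyTorch sketch of the pattern under review, keeping log-probabilities and probabilities consistent (not the PR's exact code):

import torch
import torch.nn.functional as F

logits = torch.randn(2, 3, 4, 4)   # B, C, H, W

# Multi-class: log_softmax is numerically stabler than log(softmax(x)).
log_probs = F.log_softmax(logits, dim=1)
probs = log_probs.exp()            # same values as torch.softmax(logits, dim=1)

# Binary/per-channel: logsigmoid(x) = log(sigmoid(x)), logsigmoid(-x) = log(1 - sigmoid(x)).
log_p = F.logsigmoid(logits)
log_1mp = F.logsigmoid(-logits)
p = torch.sigmoid(logits)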

Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)
  • Docstring Coverage (⚠️ Warning): docstring coverage is 50.00%, below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve coverage.
  • Title check (❓ Inconclusive): the title is truncated and vague, using an ellipsis without stating the actual objective. Resolution: complete the title to convey the main change, e.g. 'Add sigmoid/softmax support and multi-class extension for AsymmetricUnifiedFocalLoss'.
✅ Passed checks (3 passed)
  • Description check (✅ Passed): the description covers the objectives, but required checklist items remain unchecked despite being applicable.
  • Linked Issues check (✅ Passed): the PR addresses all coding objectives from #8603: use_softmax support, multi-class extension, and refactoring to reuse existing loss functions.
  • Out of Scope Changes check (✅ Passed): all changes align with the #8603 objectives; modifications are scoped to the AsymmetricFocalLoss variants and their tests.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

🧹 Nitpick comments (2)
monai/losses/unified_focal_loss.py (1)

243-244: Redundant shape check.

Both AsymmetricFocalLoss and AsymmetricFocalTverskyLoss already validate shapes. This check in the parent is redundant when to_onehot_y=False. When to_onehot_y=True, y_true shape changes after one-hot encoding in sub-losses, making this pre-check potentially incorrect.

Consider removing this check or adjusting it to account for the to_onehot_y case:

-        if y_pred.shape != y_true.shape:
-            raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
+        if not self.to_onehot_y and y_pred.shape != y_true.shape:
+            raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
tests/losses/test_unified_focal_loss.py (1)

22-50: Test coverage looks reasonable but could be expanded.

Cases cover:

  • Binary segmentation with default params
  • Binary segmentation with explicit sigmoid
  • Multi-class with softmax + one-hot

Missing coverage for edge cases:

  • use_softmax=True with binary (single foreground channel)
  • Multi-class with use_softmax=False (sigmoid)
  • Non-zero expected loss values to verify magnitude correctness

Consider adding at least one test case with a non-zero expected loss to verify the loss magnitude calculation:

# Case 3: Verify non-trivial loss value
[
    {},
    {
        "y_pred": torch.tensor([[[[0.0, 0.0], [0.0, 0.0]]], [[[0.0, 0.0], [0.0, 0.0]]]]),
        "y_true": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]),
    },
    # expected_val: compute expected non-zero value
],
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between f493ecd and d724a95.

📒 Files selected for processing (2)
  • monai/losses/unified_focal_loss.py (11 hunks)
  • tests/losses/test_unified_focal_loss.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • tests/losses/test_unified_focal_loss.py
  • monai/losses/unified_focal_loss.py
🧬 Code graph analysis (2)
tests/losses/test_unified_focal_loss.py (1)
monai/losses/unified_focal_loss.py (1)
  • AsymmetricUnifiedFocalLoss (179-257)
monai/losses/unified_focal_loss.py (1)
monai/utils/enums.py (1)
  • LossReduction (253-264)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: build-docs
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: packaging
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: min-dep-os (macOS-latest)
🔇 Additional comments (5)
monai/losses/unified_focal_loss.py (3)

17-17: LGTM!

Import addition for torch.nn.functional is necessary for the new log_softmax and logsigmoid operations.


143-148: LGTM!

Good numerical stability practice using log_softmax/logsigmoid instead of computing log(softmax(...)).


170-175: Verify intentional asymmetry between background and foreground aggregation.

back_ce is a single-channel value while fore_ce sums across all foreground channels before stacking. This means multi-class scenarios weight the foreground loss by the number of foreground classes; see the small sketch below. Confirm this matches the paper's intended behavior.
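To make the scaling concern concrete, a small sketch with hypothetical per-class values (not taken from the PR):

import torch

# Hypothetical per-class foreground loss terms for a 4-class problem (3 foreground classes).
fore_ce = torch.tensor([[0.2, 0.2, 0.2]])   # shape (batch, n_foreground)

summed = torch.sum(fore_ce, dim=1)    # ≈ tensor([0.6]): grows with the number of classes
averaged = torch.mean(fore_ce, dim=1) # ≈ tensor([0.2]): independent of class count
print(summed, averaged)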

tests/losses/test_unified_focal_loss.py (2)

55-59: LGTM!

Parameterized test properly accepts constructor arguments and validates loss output.


66-75: LGTM!

CUDA test properly validates loss computation on GPU with updated input values.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
monai/losses/unified_focal_loss.py (1)

36-44: reduction parameter accepted but ignored.

AsymmetricFocalTverskyLoss.__init__ accepts reduction (line 43) and stores it (line 54), but forward() always returns torch.mean() (line 103) regardless of the reduction setting. Same issue in AsymmetricFocalLoss. Either implement reduction in these sub-losses or remove the parameter.

 def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
     # ... existing logic ...
-    loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1))
-    return loss
+    loss = torch.stack([back_dice, fore_dice], dim=-1)
+    if self.reduction == LossReduction.SUM.value:
+        return torch.sum(loss)
+    if self.reduction == LossReduction.NONE.value:
+        return loss
+    if self.reduction == LossReduction.MEAN.value:
+        return torch.mean(loss)
+    raise ValueError(f'Unsupported reduction: {self.reduction}')

Also applies to: 103-103

♻️ Duplicate comments (1)
monai/losses/unified_focal_loss.py (1)

134-134: Docstring default value incorrect.

States gamma defaults to 0.75, but actual default is 2.

-            gamma : value of the exponent gamma in the definition of the Focal loss  . Defaults to 0.75.
+            gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between d6e4335 and 8731b30.

📒 Files selected for processing (1)
  • monai/losses/unified_focal_loss.py (11 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/losses/unified_focal_loss.py
🧬 Code graph analysis (1)
monai/losses/unified_focal_loss.py (1)
monai/utils/enums.py (1)
  • LossReduction (253-264)
🔇 Additional comments (4)
monai/losses/unified_focal_loss.py (4)

17-17: LGTM—Import required for log probability functions.


28-66: LGTM—Activation logic and parameter addition correct.

Typo fix and multi-class support properly documented. Applying activation before validation is correct since inputs are expected to be logits.


145-151: LGTM—Log probability handling is numerically stable.

Using log_softmax/logsigmoid with separate log and probability tensors avoids numerical issues.


225-230: LGTM—Sub-losses properly configured with use_softmax.

Parameters correctly forwarded to both sub-losses.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (1)
monai/losses/unified_focal_loss.py (1)

180-183: Aggregation inconsistency persists.

torch.sum here vs torch.mean at line 99 in AsymmetricFocalTverskyLoss. Loss magnitude scales differently with class count.

🧹 Nitpick comments (2)
monai/losses/unified_focal_loss.py (2)

67-77: Warning check is unreachable after single-channel expansion.

After lines 67-69 expand single-channel to two-class, n_pred_ch at line 71 will always be ≥2. The warning at lines 74-75 becomes dead code for originally single-channel inputs.

Consider moving the warning check before the expansion, or remove it since the expansion handles single-channel inputs.

+        n_pred_ch = y_pred.shape[1]
+        if self.to_onehot_y and n_pred_ch == 1:
+            warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+
         if y_pred.shape[1] == 1:
             y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
             y_true = torch.cat([1 - y_true, y_true], dim=1)

-        n_pred_ch = y_pred.shape[1]
-
-        if self.to_onehot_y:
-            if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
-            else:
-                y_true = one_hot(y_true, num_classes=n_pred_ch)
+        if self.to_onehot_y and y_pred.shape[1] > 1:
+            y_true = one_hot(y_true, num_classes=y_pred.shape[1])

254-260: Reduction logic has no effect.

Sub-losses at lines 249-250 return scalars (they use their own reduction). The outer reduction here (lines 254-260) on a scalar is effectively a no-op. Consider passing reduction to sub-losses or documenting this behavior.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between ea8a6ee and 6c77725.

📒 Files selected for processing (1)
  • monai/losses/unified_focal_loss.py (10 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/losses/unified_focal_loss.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: build-docs
  • GitHub Check: packaging
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: min-dep-py3 (3.9)
🔇 Additional comments (1)
monai/losses/unified_focal_loss.py (1)

233-238: LGTM!

Sub-losses correctly wired with shared parameters.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (3)
monai/losses/unified_focal_loss.py (3)

36-80: Fix single‑channel + to_onehot_y ordering and clamp foreground dice base.

Two points here:

  1. Bug: to_onehot_y after single‑channel expansion.
    In forward, you expand 1‑channel predictions/targets before computing n_pred_ch and applying to_onehot_y:

    • For binary cases with to_onehot_y=True, this now calls one_hot() on a 2‑channel float tensor (y_true = cat([1 - y_true, y_true], ...)) instead of the original label map, and it also defeats the previous “single channel prediction, to_onehot_y=True ignored.” behaviour.

    Reorder so n_pred_ch and to_onehot_y are evaluated on the original channel count, then expand to two channels:

-        if self.use_softmax:
-            y_pred = torch.softmax(y_pred, dim=1)
-        else:
-            y_pred = torch.sigmoid(y_pred)
-
-        if y_pred.shape[1] == 1:
-            y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
-            y_true = torch.cat([1 - y_true, y_true], dim=1)
-
-        n_pred_ch = y_pred.shape[1]
-
-        if self.to_onehot_y:
-            if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
-            else:
-                y_true = one_hot(y_true, num_classes=n_pred_ch)
+        if self.use_softmax:
+            y_pred = torch.softmax(y_pred, dim=1)
+        else:
+            y_pred = torch.sigmoid(y_pred)
+
+        n_pred_ch = y_pred.shape[1]
+
+        if self.to_onehot_y:
+            if n_pred_ch == 1:
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+            else:
+                y_true = one_hot(y_true, num_classes=n_pred_ch)
+
+        if y_pred.shape[1] == 1:
+            y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
+            y_true = torch.cat([1 - y_true, y_true], dim=1)

  2. Numeric stability: clamp base before power.
     fore_dice = torch.pow(1 - dice_class[:, 1:], 1 - self.gamma) can still explode if a user passes gamma > 1 and dice_class[:, 1:] → 1 (negative exponent on values near zero). To guard this, clamp the base:

-        fore_dice = torch.pow(1 - dice_class[:, 1:], 1 - self.gamma)
+        fore_dice_base = torch.clamp(1 - dice_class[:, 1:], min=self.epsilon)
+        fore_dice = torch.pow(fore_dice_base, 1 - self.gamma)

Also applies to: 84-99


118-185: Reorder to_onehot_y, fix single‑channel log‑prob stability, and consider foreground aggregation.

This block has three intertwined issues:

  1. Bug: to_onehot_y after single‑channel expansion (same as Tversky).
     Expanding y_pred/y_true to two channels before computing n_pred_ch and applying to_onehot_y means that for binary cases with to_onehot_y=True you now call one_hot() on an already expanded 2‑channel float tensor instead of the original label map, and you lose the prior "ignore + warn" semantics.

     Reorder to evaluate to_onehot_y on the original channel count, then expand:

-    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
-        if self.use_softmax:
-            y_log_pred = F.log_softmax(y_pred, dim=1)
-            y_pred = torch.exp(y_log_pred)
-        else:
-            y_log_pred = F.logsigmoid(y_pred)
-            y_pred = torch.sigmoid(y_pred)
-
-        if y_pred.shape[1] == 1:
-            y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
-            y_log_pred = torch.log(torch.clamp(y_pred, 1e-7, 1.0))
-            y_true = torch.cat([1 - y_true, y_true], dim=1)
-
-        n_pred_ch = y_pred.shape[1]
-
-        if self.to_onehot_y:
-            if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
-            else:
-                y_true = one_hot(y_true, num_classes=n_pred_ch)
+    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
+        logits = y_pred
+        if self.use_softmax:
+            y_log_pred = F.log_softmax(logits, dim=1)
+            y_pred = torch.exp(y_log_pred)
+        else:
+            y_log_pred = F.logsigmoid(logits)
+            y_pred = torch.sigmoid(logits)
+
+        n_pred_ch = y_pred.shape[1]
+
+        if self.to_onehot_y:
+            if n_pred_ch == 1:
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+            else:
+                y_true = one_hot(y_true, num_classes=n_pred_ch)
+
+        if y_pred.shape[1] == 1:
+            if self.use_softmax:
+                # Softmax over a single channel is degenerate; fall back to prob-based log.
+                y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
+                y_log_pred = torch.log(torch.clamp(y_pred, self.epsilon, 1.0))
+            else:
+                pos_prob = y_pred
+                neg_prob = 1 - pos_prob
+                pos_log_prob = y_log_pred
+                neg_log_prob = F.logsigmoid(-logits)
+                y_pred = torch.cat([neg_prob, pos_prob], dim=1)
+                y_log_pred = torch.cat([neg_log_prob, pos_log_prob], dim=1)
+            y_true = torch.cat([1 - y_true, y_true], dim=1)

  2. Numeric stability: don't re‑log probabilities.
     The change above also addresses the prior issue where you recomputed y_log_pred = log(clamp(y_pred)) in the 1‑channel path, discarding the stable log_softmax/logsigmoid computed from logits. For the binary sigmoid case we now derive log p and log(1 - p) directly from logits using F.logsigmoid, which is the numerically stable form.

  3. Foreground CE aggregation vs Tversky (optional but recommended).
     Here foreground CE aggregates with sum over foreground classes:

     fore_ce = cross_entropy[:, 1:]
     fore_ce = self.delta * fore_ce
     if fore_ce.shape[1] > 1:
         fore_ce = torch.sum(fore_ce, dim=1)

     while AsymmetricFocalTverskyLoss uses mean across foreground classes. This keeps the unified loss scale tied to the number of foreground classes. If you want scale invariance w.r.t. class count and consistency with Tversky, switching to mean is cleaner:

-        if fore_ce.shape[1] > 1:
-            fore_ce = torch.sum(fore_ce, dim=1)
+        if fore_ce.shape[1] > 1:
+            fore_ce = torch.mean(fore_ce, dim=1)

     If the `sum` behaviour is intentional, consider documenting the choice and adding a unit test that asserts the scaling.

211-216: Clarify y_true shape for to_onehot_y=True in docstring.

The forward docstring still says y_true is always BNH[WD], but with to_onehot_y=True you expect label maps that are converted to one‑hot internally. This matches an earlier comment and is still unresolved.

Consider updating as:

         Args:
             y_pred : the shape should be BNH[WD], where N is the number of classes.
                 The input should be the original logits since it will be transformed by
-                    a sigmoid/softmax in the forward function.
-            y_true : the shape should be BNH[WD], where N is the number of classes.
+                    a sigmoid/softmax in the forward function.
+            y_true : the shape should be BNH[WD] (one-hot format), or B1H[WD] with class
+                indices when ``to_onehot_y=True``.

Also applies to: 241-246

🧹 Nitpick comments (1)
monai/losses/unified_focal_loss.py (1)

47-53: Tighten epsilon docstrings to state purpose.

The epsilon descriptions are grammatically off and don’t say what epsilon is for. A more direct wording would help:

-            epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7.
+            epsilon : small constant for numerical stability (e.g., to avoid division by
+                zero or log(0)). Defaults to 1e-7.

Apply similarly to both AsymmetricFocalTverskyLoss and AsymmetricFocalLoss.

Also applies to: 129-136

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between c793454 and 3b7277f.

📒 Files selected for processing (1)
  • monai/losses/unified_focal_loss.py (9 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/losses/unified_focal_loss.py
🧬 Code graph analysis (1)
monai/losses/unified_focal_loss.py (1)
monai/utils/enums.py (1)
  • LossReduction (253-264)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: packaging
  • GitHub Check: build-docs
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: flake8-py3 (codeformat)
🔇 Additional comments (1)
monai/losses/unified_focal_loss.py (1)

200-237: Constructor wiring for use_softmax looks correct; ensure tests cover both modes.

The new use_softmax flag is correctly threaded from AsymmetricUnifiedFocalLoss into AsymmetricFocalLoss and AsymmetricFocalTverskyLoss along with to_onehot_y, gamma, and delta. That matches the intended API.

Please make sure tests exercise at least:

  • binary and multi-class cases with use_softmax=True (logits + one-hot/label inputs),
  • binary sigmoid cases with use_softmax=False (including single-channel outputs),
  • to_onehot_y both True and False in combination with the above.

Signed-off-by: ytl0623 <[email protected]>
Signed-off-by: ytl0623 <[email protected]>

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
monai/losses/unified_focal_loss.py (1)

61-80: Fix two critical issues with single-channel handling and to_onehot_y semantics in both AsymmetricFocalTverskyLoss and AsymmetricFocalLoss.

Verification confirms both issues:

  1. Issue 1 verified: When y_pred.shape[1]==1 with to_onehot_y=True, one_hot() is called on an already 2-channel expanded y_true (containing [1-y_true, y_true]). This violates the one_hot contract—it expects scalar class indices, not probabilities. Test cases do not cover this scenario (see the one_hot sketch after this list).

  2. Issue 2 verified: When use_softmax=True with single-channel input, softmax(x, dim=1) returns all 1.0, then concatenation produces exactly [0, 1] everywhere, completely discarding input logits. Test cases do not cover this either.

Apply the suggested refactoring pattern to both classes:

  • Perform to_onehot_y conversion before expanding channels
  • Explicitly disallow use_softmax=True with single-channel predictions (raise ValueError or strong warning)
  • Add tests for both edge cases post-fix
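For reference on the one_hot contract mentioned in point 1, a minimal sketch using monai.networks.utils.one_hot (behaviour as commonly documented; verify against the pinned MONAI version):

import torch
from monai.networks.utils import one_hot

# one_hot expects integer class indices in a single channel, not probabilities.
labels = torch.randint(0, 3, (2, 1, 4, 4))         # B1HW label map
labels_oh = one_hot(labels, num_classes=3, dim=1)  # B3HW one-hot
print(labels_oh.shape)  # torch.Size([2, 3, 4, 4])

# Passing an already-expanded probability tensor here would violate that contract.
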
♻️ Duplicate comments (1)
monai/losses/unified_focal_loss.py (1)

118-135: Single‑channel AsymmetricFocalLoss recomputes log‑probs from clamped probs instead of logits.

This is the same numerical‑stability concern previously raised: after computing y_log_pred via log_softmax/logsigmoid, the single‑channel path overwrites it with:

if y_pred.shape[1] == 1:
    y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
    y_log_pred = torch.log(torch.clamp(y_pred, 1e-7, 1.0))

For large‑magnitude logits, log(1 - sigmoid(z)) via 1 - p suffers cancellation and is capped by the clamp, whereas the stable form from logits is:

  • log p1 = F.logsigmoid(z)
  • log p0 = F.logsigmoid(-z)

Suggested refactor (keep logits and derive both channels from them):

def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
-    if self.use_softmax:
-        y_log_pred = F.log_softmax(y_pred, dim=1)
-        y_pred = torch.exp(y_log_pred)
-    else:
-        y_log_pred = F.logsigmoid(y_pred)
-        y_pred = torch.sigmoid(y_pred)
-
-    if y_pred.shape[1] == 1:
-        y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
-        y_log_pred = torch.log(torch.clamp(y_pred, 1e-7, 1.0))
+    logits = y_pred
+    if self.use_softmax:
+        y_log_pred = F.log_softmax(logits, dim=1)
+        y_pred = torch.exp(y_log_pred)
+    else:
+        y_pred = torch.sigmoid(logits)
+        if logits.shape[1] == 1:
+            # Binary: compute stable log-probs for both channels from logits.
+            log_p1 = F.logsigmoid(logits)
+            log_p0 = F.logsigmoid(-logits)
+            y_log_pred = torch.cat([log_p0, log_p1], dim=1)
+            y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
+        else:
+            y_log_pred = F.logsigmoid(logits)
@@
-    if y_pred.shape[1] == 1:
-        ...
-        y_true = torch.cat([1 - y_true, y_true], dim=1)
+    if y_pred.shape[1] == 1 and not self.to_onehot_y:
+        y_true = torch.cat([1 - y_true, y_true], dim=1)

This keeps the stable log‑probability computation and avoids recomputing logs from clamped probabilities.
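A tiny numerical check of the instability being fixed (illustrative values only):

import torch
import torch.nn.functional as F

z = torch.tensor([20.0, 40.0])

naive = torch.log(1 - torch.sigmoid(z))    # -inf once sigmoid(z) rounds to 1.0 in float32
clamped = torch.log(torch.clamp(1 - torch.sigmoid(z), 1e-7, 1.0))  # capped at log(1e-7) ≈ -16.1
stable = F.logsigmoid(-z)                  # close to exact: tensor([-20., -40.])
print(naive, clamped, stable)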


Also applies to: 143-155

🧹 Nitpick comments (1)
monai/losses/unified_focal_loss.py (1)

36-53: AsymmetricFocalTverskyLoss.__init__ use_softmax wiring and docs mostly OK, epsilon doc could be clearer.

  • Signature and stored self.use_softmax match the new behavior; docstrings align with the default values.
  • The epsilon line is grammatically odd and vague ("similarly smooth value"). Consider rephrasing to something like:
-            epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7.
+            epsilon : small smoothing constant to avoid division by zero. Defaults to 1e-7.
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 0640d0f and 6c25189.

📒 Files selected for processing (1)
  • monai/losses/unified_focal_loss.py (10 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/losses/unified_focal_loss.py
🧬 Code graph analysis (1)
monai/losses/unified_focal_loss.py (1)
monai/utils/enums.py (1)
  • LossReduction (253-264)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: build-docs
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: packaging
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-py3 (3.10)
🔇 Additional comments (1)
monai/losses/unified_focal_loss.py (1)

187-236: Ensure test coverage spans binary vs. multi‑class scenarios with both to_onehot_y modes; verify handling of edge cases in composed sub‑losses.

Tests in tests/losses/test_unified_focal_loss.py include binary segmentation cases (Case 0) and confirm constructor threading is sound. However, verify that:

  • Test coverage explicitly covers multi-class (C>1) scenarios alongside binary, with both to_onehot_y=False and to_onehot_y=True.
  • use_softmax=True behaves correctly for multi-class (expected) and gracefully for binary if a guard exists in sub‑losses.
  • Both AsymmetricFocalLoss and AsymmetricFocalTverskyLoss handle the full parameter matrix correctly when composed.

Signed-off-by: ytl0623 <[email protected]>
Signed-off-by: ytl0623 <[email protected]>

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (3)
monai/losses/unified_focal_loss.py (3)

61-81: Softmax/sigmoid gating and binary expansion look sound; consider warning stacklevel.

The guard against use_softmax with a single channel, the to_onehot_y path, and the binary expansion to two channels are all consistent with the new multi-class/binary semantics. For better caller-facing diagnostics, consider adding stacklevel=2 to the warnings.warn call so the warning points at the user code rather than inside the loss implementation.


147-196: Logits-based CE and single-channel expansion are numerically stable and shape-consistent.

Storing logits, using F.log_softmax / F.logsigmoid, and deriving log(1-p) via F.logsigmoid(-y_logits) fixes the earlier numerical-stability problem when expanding a single-channel sigmoid output to two channels. The foreground aggregation via mean across classes 1+ is also in line with the Tversky variant and the multi-class design.

One minor clean-up you could consider is only computing y_logits when n_pred_ch == 1, since it’s unused otherwise.


248-267: Unified loss reduction currently ignores the reduction argument for sub-losses.

AsymmetricUnifiedFocalLoss always constructs AsymmetricFocalLoss and AsymmetricFocalTverskyLoss with their default reduction=LossReduction.MEAN, then applies another reduction on the combined loss. This means reduction="none" or "sum" on AsymmetricUnifiedFocalLoss can’t actually produce un-reduced per-sample outputs; you only ever see a scalar.

If you want reduction on this wrapper to behave like other MONAI losses, consider either:

  • Passing reduction=self.reduction into both sub-loss constructors and treating their outputs as final (dropping the extra reduction here), or
  • Setting the sub-losses to reduction=LossReduction.NONE and keeping the current final reduction logic solely in this class (sketched below).

Worth double-checking against existing callers and tests to ensure the intended behavior.
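To make the second option concrete, a rough sketch of the pattern; this is illustrative, not the PR's code, with generic element-wise sub-losses standing in for the asymmetric losses and the LossReduction values as listed in monai/utils/enums.py:

import torch
from monai.utils import LossReduction


class CompositeLoss(torch.nn.Module):
    """Illustrative wrapper: combine two sub-losses built with reduction="none", reduce once here."""

    def __init__(self, loss_a, loss_b, weight=0.5, reduction=LossReduction.MEAN):
        super().__init__()
        self.loss_a, self.loss_b = loss_a, loss_b
        self.weight = weight
        self.reduction = LossReduction(reduction).value

    def forward(self, y_pred, y_true):
        # Sub-losses are assumed to return unreduced (per-element) values.
        loss = self.weight * self.loss_a(y_pred, y_true) + (1 - self.weight) * self.loss_b(y_pred, y_true)
        if self.reduction == LossReduction.SUM.value:
            return loss.sum()
        if self.reduction == LossReduction.NONE.value:
            return loss
        return loss.mean()


comp = CompositeLoss(torch.nn.MSELoss(reduction="none"), torch.nn.L1Loss(reduction="none"), reduction="none")
print(comp(torch.randn(2, 3), torch.zeros(2, 3)).shape)  # torch.Size([2, 3])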

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 6c25189 and e0e48a3.

📒 Files selected for processing (1)
  • monai/losses/unified_focal_loss.py (9 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/losses/unified_focal_loss.py
🧬 Code graph analysis (1)
monai/losses/unified_focal_loss.py (2)
monai/utils/enums.py (1)
  • LossReduction (253-264)
monai/networks/utils.py (1)
  • one_hot (170-220)
🪛 Ruff (0.14.5)
monai/losses/unified_focal_loss.py

65-65: Avoid specifying long messages outside the exception class

(TRY003)


69-69: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


83-83: Avoid specifying long messages outside the exception class

(TRY003)


151-151: Avoid specifying long messages outside the exception class

(TRY003)


155-155: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


177-177: Avoid specifying long messages outside the exception class

(TRY003)


237-237: Undefined name weight

(F821)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: build-docs
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: packaging
  • GitHub Check: flake8-py3 (mypy)
🔇 Additional comments (1)
monai/losses/unified_focal_loss.py (1)

88-107: Multi-class foreground/background aggregation is consistent and symmetric.

Using class 0 as background and aggregating classes 1+ via mean keeps scaling of the foreground term independent of the number of foreground classes and matches the updated asymmetric focal behavior. No issues from a correctness standpoint.

Comment on lines 157 to 246
     def __init__(
         self,
         to_onehot_y: bool = False,
-        num_classes: int = 2,
-        weight: float = 0.5,
-        gamma: float = 0.5,
+        use_softmax: bool = False,
         delta: float = 0.7,
+        gamma: float = 2,
         reduction: LossReduction | str = LossReduction.MEAN,
     ):
         """
         Args:
-            to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
-            num_classes : number of classes, it only supports 2 now. Defaults to 2.
+            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
+            use_softmax: whether to use softmax to transform the original logits into probabilities.
+                If True, softmax is used. If False, sigmoid is used. Defaults to False.
             delta : weight of the background. Defaults to 0.7.
             gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
-            epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
-            weight : weight for each loss function, if it's none it's 0.5. Defaults to None.
         Example:
             >>> import torch
             >>> from monai.losses import AsymmetricUnifiedFocalLoss
-            >>> pred = torch.ones((1,1,32,32), dtype=torch.float32)
-            >>> grnd = torch.ones((1,1,32,32), dtype=torch.int64)
-            >>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True)
+            >>> pred = torch.randn((1, 3, 32, 32))
+            >>> grnd = torch.randint(0, 3, (1, 1, 32, 32))
+            >>> fl = AsymmetricUnifiedFocalLoss(use_softmax=True, to_onehot_y=True)
             >>> fl(pred, grnd)
         """
         super().__init__(reduction=LossReduction(reduction).value)
         self.to_onehot_y = to_onehot_y
-        self.num_classes = num_classes
-        self.weight: float = weight
         self.gamma = gamma
         self.delta = delta
+        self.weight: float = weight
-        self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta)
-        self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta)
+        self.use_softmax = use_softmax
+        self.asy_focal_loss = AsymmetricFocalLoss(
+            to_onehot_y=to_onehot_y, gamma=self.gamma, delta=self.delta, use_softmax=use_softmax
+        )
+        self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
+            to_onehot_y=to_onehot_y, gamma=self.gamma, delta=self.delta, use_softmax=use_softmax
+        )

⚠️ Potential issue | 🔴 Critical

weight is undefined in __init__ and docstring gamma default is stale.

AsymmetricUnifiedFocalLoss.__init__ doesn’t accept a weight argument but assigns self.weight: float = weight, which will raise at runtime. The docstring also still says gamma defaults to 0.75 while the signature uses 2, and it doesn’t document weight at all.

Recommend restoring weight as an explicit argument with a sane default and aligning the docstring with the implementation. For example:

-    def __init__(
-        self,
-        to_onehot_y: bool = False,
-        use_softmax: bool = False,
-        delta: float = 0.7,
-        gamma: float = 2,
-        reduction: LossReduction | str = LossReduction.MEAN,
-    ):
+    def __init__(
+        self,
+        to_onehot_y: bool = False,
+        use_softmax: bool = False,
+        delta: float = 0.7,
+        gamma: float = 2,
+        weight: float = 0.5,
+        reduction: LossReduction | str = LossReduction.MEAN,
+    ):
@@
-        Args:
-            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
-            use_softmax: whether to use softmax to transform the original logits into probabilities.
-                If True, softmax is used. If False, sigmoid is used. Defaults to False.
-            delta : weight of the background. Defaults to 0.7.
-            gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
+        Args:
+            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
+            use_softmax: whether to use softmax to transform the original logits into probabilities.
+                If True, softmax is used. If False, sigmoid is used. Defaults to False.
+            delta : weight of the background. Defaults to 0.7.
+            gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
+            weight: weight for combining the focal and focal-Tversky terms. Defaults to 0.5.
@@
-        self.to_onehot_y = to_onehot_y
-        self.weight: float = weight
+        self.to_onehot_y = to_onehot_y
+        self.weight: float = weight
@@
-        self.asy_focal_loss = AsymmetricFocalLoss(
-            to_onehot_y=to_onehot_y, gamma=self.gamma, delta=self.delta, use_softmax=use_softmax
-        )
-        self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
-            to_onehot_y=to_onehot_y, gamma=self.gamma, delta=self.delta, use_softmax=use_softmax
-        )
+        self.asy_focal_loss = AsymmetricFocalLoss(
+            to_onehot_y=to_onehot_y,
+            use_softmax=use_softmax,
+            delta=self.delta,
+            gamma=self.gamma,
+        )
+        self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
+            to_onehot_y=to_onehot_y,
+            use_softmax=use_softmax,
+            delta=self.delta,
+            gamma=self.gamma,
+        )

(Last part just reorders keywords for clarity; behavior is unchanged.) This keeps the public API coherent and avoids a runtime NameError. As per coding guidelines, keeping docstrings in sync with signatures is important.

🧰 Tools
🪛 Ruff (0.14.5)

237-237: Undefined name weight

(F821)

🤖 Prompt for AI Agents
In monai/losses/unified_focal_loss.py around lines 211 to 246, __init__
references an undefined variable weight and the docstring lists the wrong
default for gamma; restore weight as an explicit parameter (e.g., weight: float
= 1.0) in the signature and assign self.weight = weight with an appropriate type
annotation, update the docstring to document the new weight parameter and
correct the gamma default to match the signature (gamma=2), and ensure any
downstream uses of self.weight or parameter ordering are consistent with the
public API.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (1)
monai/losses/unified_focal_loss.py (1)

226-226: Docstring gamma default is incorrect.

Line 226 states "Defaults to 0.75" but the signature at line 216 uses gamma: float = 2.

-            gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
+            gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
🧹 Nitpick comments (1)
monai/losses/unified_focal_loss.py (1)

69-69: Add stacklevel to warning for better diagnostics.

Without stacklevel=2, the warning points to this line rather than the caller.

-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)

Same applies to line 155.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between e0e48a3 and 954df60.

📒 Files selected for processing (1)
  • monai/losses/unified_focal_loss.py (9 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/losses/unified_focal_loss.py
🧬 Code graph analysis (1)
monai/losses/unified_focal_loss.py (2)
monai/utils/enums.py (1)
  • LossReduction (253-264)
monai/networks/utils.py (1)
  • one_hot (170-220)
🪛 Ruff (0.14.5)
monai/losses/unified_focal_loss.py

65-65: Avoid specifying long messages outside the exception class

(TRY003)


69-69: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


83-83: Avoid specifying long messages outside the exception class

(TRY003)


151-151: Avoid specifying long messages outside the exception class

(TRY003)


155-155: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


177-177: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: build-docs
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: packaging
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: min-dep-os (macOS-latest)
🔇 Additional comments (5)
monai/losses/unified_focal_loss.py (5)

36-59: LGTM!

Constructor and docstrings are well-aligned. Parameters documented correctly.


122-145: LGTM!

Constructor parameters and docstrings are consistent.


159-174: Numerical stability properly addressed.

The use of F.logsigmoid(-y_logits) for background log-probability (line 172) correctly maintains numerical stability. Good fix.


182-195: LGTM!

Asymmetric focal weighting correctly applied: background gets (1-p)^gamma modulation, foreground preserved. Aggregation consistent with Tversky loss.
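For readers less familiar with the asymmetric variant, a compact, illustrative sketch of the weighting described above, using made-up tensors and a cross-entropy stand-in rather than the PR's implementation:

import torch

gamma = 2.0
p = torch.rand(2, 3, 4, 4).clamp(1e-7, 1.0)   # probabilities per class, channel 0 = background
ce = -torch.log(p)                             # per-class cross-entropy stand-in (targets omitted)

# Background (channel 0): down-weight well-classified pixels with a (1 - p)^gamma focal factor.
back_ce = torch.pow(1 - p[:, 0], gamma) * ce[:, 0]

# Foreground (channels 1+): focal factor suppressed so rare-class gradients are preserved,
# then averaged across foreground classes.
fore_ce = ce[:, 1:].mean(dim=1)

loss = torch.mean(torch.stack([back_ce, fore_ce], dim=-1))
print(loss)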


250-270: LGTM!

Forward delegation and reduction logic are correct. Docstring properly documents both shape options.

Comment on lines +97 to +103
# Class 1+ is Foreground
fore_dice = torch.pow(1 - dice_class[:, 1:], 1 - self.gamma)

if fore_dice.shape[1] > 1:
    fore_dice = torch.mean(fore_dice, dim=1)
else:
    fore_dice = fore_dice.squeeze(1)

⚠️ Potential issue | 🟡 Minor

Potential numerical issue when gamma > 1.

When gamma > 1, the exponent 1 - gamma becomes negative. If dice_class[:, 1:] approaches 1.0 (perfect prediction), (1 - dice_class) approaches 0, and 0^negative explodes.

Default gamma=0.75 is safe, but document or validate that gamma <= 1 is expected for this class.

+        if self.gamma > 1:
+            warnings.warn("gamma > 1 may cause numerical instability in AsymmetricFocalTverskyLoss")
+        
         # Class 1+ is Foreground
         fore_dice = torch.pow(1 - dice_class[:, 1:], 1 - self.gamma)

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In monai/losses/unified_focal_loss.py around lines 97 to 103, the computation
fore_dice = torch.pow(1 - dice_class[:, 1:], 1 - self.gamma) can explode when
gamma > 1 because the exponent becomes negative and values of (1 - dice) near
zero lead to huge results; to fix this, validate or document gamma constraints
and add a guard: either assert or raise a clear ValueError if self.gamma > 1, or
clamp the base with a small epsilon (e.g., base = torch.clamp(1 - dice_class[:,
1:], min=eps)) before the pow, ensuring stable numerical behavior for negative
exponents; include a unit-test or a docstring note that gamma is expected <= 1
if you choose the validation route.
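A quick numerical illustration of the failure mode, with hypothetical Dice values:

import torch

gamma = 2.0
dice = torch.tensor([0.5, 0.999999, 1.0])   # per-class Dice approaching a perfect score

raw = torch.pow(1 - dice, 1 - gamma)        # exponent -1: roughly [2.0, 1e6, inf]
safe = torch.pow(torch.clamp(1 - dice, min=1e-7), 1 - gamma)  # last entry capped at 1e7 instead of inf
print(raw, safe)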

Comment on lines 246 to 248
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
    to_onehot_y=to_onehot_y, gamma=self.gamma, delta=self.delta, use_softmax=use_softmax
)

⚠️ Potential issue | 🟠 Major

Numerical instability: gamma=2 causes explosion in Tversky component.

AsymmetricFocalTverskyLoss computes (1 - dice)^(1 - gamma). With gamma=2 (default here), exponent = -1. When dice approaches 1.0 (near-perfect prediction), this explodes to infinity.

Consider using separate gamma values for each sub-loss or clamping the base in Tversky:

 self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
-    to_onehot_y=to_onehot_y, gamma=self.gamma, delta=self.delta, use_softmax=use_softmax
+    to_onehot_y=to_onehot_y, gamma=0.75, delta=self.delta, use_softmax=use_softmax
 )

Or add clamping in AsymmetricFocalTverskyLoss.forward:

fore_dice_base = torch.clamp(1 - dice_class[:, 1:], min=self.epsilon)
fore_dice = torch.pow(fore_dice_base, 1 - self.gamma)
🤖 Prompt for AI Agents
In monai/losses/unified_focal_loss.py around lines 246-248, the
AsymmetricFocalTverskyLoss is instantiated with gamma=self.gamma which can make
the Tversky term explode when gamma>1; modify AsymmetricFocalTverskyLoss instead
to clamp the base before exponentiation (e.g., compute fore_dice_base =
torch.clamp(1 - dice_class[:, 1:], min=self.epsilon)) and then raise to the (1 -
gamma) power, and/or add an optional separate parameter (e.g., tversky_gamma or
epsilon) to the AsymmetricFocalTverskyLoss constructor so callers can pass
different gamma or epsilon values without changing unified_focal_loss.py. Ensure
the forward uses the clamp value and existing epsilon (or new param) before
calling torch.pow to prevent dividing/exponentiating by values near zero.

Signed-off-by: ytl0623 <[email protected]>

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (1)
monai/losses/unified_focal_loss.py (1)

97-103: Numerical instability when gamma > 1.

torch.pow(1 - dice_class[:, 1:], 1 - self.gamma) uses exponent 1 - gamma. Default gamma=0.75 is safe, but if a user passes gamma > 1, the exponent becomes negative. When dice approaches 1.0 (near-perfect prediction), (1 - dice) approaches 0, and 0^negative explodes.

Add validation in __init__:

+        if gamma > 1:
+            raise ValueError(f"gamma must be <= 1 for AsymmetricFocalTverskyLoss, got {gamma}")
         self.gamma = gamma

Or clamp the base before applying the power:

-        fore_dice = torch.pow(1 - dice_class[:, 1:], 1 - self.gamma)
+        fore_dice_base = torch.clamp(1 - dice_class[:, 1:], min=self.epsilon)
+        fore_dice = torch.pow(fore_dice_base, 1 - self.gamma)
🧹 Nitpick comments (2)
monai/losses/unified_focal_loss.py (2)

69-69: Missing stacklevel in warning.

-            warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+            warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)

155-155: Missing stacklevel in warning.

-            warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+            warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 954df60 and ee2215a.

📒 Files selected for processing (1)
  • monai/losses/unified_focal_loss.py (9 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/losses/unified_focal_loss.py
🧬 Code graph analysis (1)
monai/losses/unified_focal_loss.py (2)
monai/utils/enums.py (1)
  • LossReduction (253-264)
monai/networks/utils.py (1)
  • one_hot (170-220)
🪛 Ruff (0.14.5)
monai/losses/unified_focal_loss.py

65-65: Avoid specifying long messages outside the exception class

(TRY003)


69-69: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


83-83: Avoid specifying long messages outside the exception class

(TRY003)


151-151: Avoid specifying long messages outside the exception class

(TRY003)


155-155: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


177-177: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (3)
monai/losses/unified_focal_loss.py (3)

159-174: Numerically stable single-channel expansion.

Preserving logits and using F.logsigmoid(-y_logits) for background log-probabilities is the correct approach. This avoids log(1 - sigmoid(x)) instability.


243-254: Sub-loss wiring looks correct.

Both sub-losses receive use_softmax parameter, ensuring consistent probability transformation across the unified loss.


270-276: Reduction logic requires verification against the actual sub-loss instantiation.

The repository code could not be inspected directly here, but the relevant context is:

  1. MONAI losses support reduction parameters with options: "none", "mean", "sum".
  2. Sub-losses can compute per-sample losses when initialized with reduction="none".
  3. The original Unified Focal Loss implementation computes per-sample losses that require reduction.

The concern raised is structurally sound: if the sub-losses use reduction="mean" by default, applying torch.sum() or torch.mean() at lines 271–275 would be redundant no-ops on scalars; if they return per-sample tensors (via reduction="none"), this reduction logic is necessary. Verify: (1) how AsymmetricFocalLoss and AsymmetricFocalTverskyLoss are initialized in lines ~260–268, (2) whether their reduction parameter is explicitly set or uses defaults, and (3) the actual shape of loss at line 268. A quick probe is sketched below.
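One way to run that check, as a hypothetical probe assuming the constructor discussed in this PR (to_onehot_y, use_softmax, reduction) and that the undefined-weight issue flagged above has been resolved:

import torch
from monai.losses import AsymmetricUnifiedFocalLoss

pred = torch.randn(2, 3, 8, 8)               # multi-class logits
grnd = torch.randint(0, 3, (2, 1, 8, 8))     # label map

for reduction in ("mean", "sum", "none"):
    loss_fn = AsymmetricUnifiedFocalLoss(use_softmax=True, to_onehot_y=True, reduction=reduction)
    out = loss_fn(pred, grnd)
    print(reduction, out.shape)  # a 0-dim result for "none" would confirm the no-op concern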


Successfully merging this pull request may close these issues.

Add sigmoid/softmax interface for AsymmetricUnifiedFocalLoss
