Conversation

@ytl0623 ytl0623 commented Dec 22, 2025

Fixes #8603

Description

Refactors AsymmetricUnifiedFocalLoss and its sub-components (AsymmetricFocalLoss, AsymmetricFocalTverskyLoss) to extend support from binary-only to multi-class segmentation, while also fixing mathematical logic errors and parameter-passing bugs.

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

coderabbitai bot commented Dec 22, 2025

Walkthrough

Adds a use_softmax option and explicit probability handling to AsymmetricFocalTverskyLoss, AsymmetricFocalLoss, and AsymmetricUnifiedFocalLoss. For single‑channel predictions the code now converts logits to probabilities (sigmoid by default; softmax when use_softmax=True), can expand to two‑class representation, and uses an is_already_prob flag to avoid redundant transforms. Target one‑hot conversion (to_onehot_y) is applied conditionally based on prediction channels. AsymmetricUnifiedFocalLoss now composes AsymmetricFocalLoss and AsymmetricFocalTverskyLoss (exposed as public members) and returns their weighted sum. Tests updated for binary logits, 2‑channel, multiclass perfect/wrong cases, shape errors, and CUDA.
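
As a quick orientation, a minimal usage sketch of the refactored interface described above (shapes and values are illustrative only; parameter names follow this walkthrough):

import torch
from monai.losses import AsymmetricUnifiedFocalLoss

# multi-class: raw logits plus integer labels converted internally via to_onehot_y
loss_fn = AsymmetricUnifiedFocalLoss(use_softmax=True, to_onehot_y=True)

y_pred = torch.randn(2, 3, 8, 8)             # (B, C, H, W) logits for 3 classes
y_true = torch.randint(0, 3, (2, 1, 8, 8))   # (B, 1, H, W) class-index labels
loss = loss_fn(y_pred, y_true)               # scalar with the default "mean" reduction
print(loss.item())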

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 60.00%, below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (4 passed)
  • Title check (✅ Passed): Title accurately reflects the main change: adding sigmoid/softmax interface support to AsymmetricUnifiedFocalLoss.
  • Description check (✅ Passed): Description covers the key changes (refactoring for multi-class support, fixes) and aligns with the template, though the integration tests and docs steps remain incomplete.
  • Linked Issues check (✅ Passed): Changes directly address the #8603 requirements: add a sigmoid/softmax interface, extend to multi-class, fix logic errors, and reuse existing loss implementations.
  • Out of Scope Changes check (✅ Passed): All changes remain scoped to AsymmetricUnifiedFocalLoss and its sub-components; no unrelated modifications detected.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

🧹 Nitpick comments (2)
monai/losses/unified_focal_loss.py (2)

76-82: Add stacklevel=2 to warning.

Per static analysis and Python conventions, set stacklevel to point to caller.

-                warnings.warn("single channel prediction, `include_background=False` ignored.")
+                warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)
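
For context, a standalone illustration (not project code) of what stacklevel=2 changes: the warning is attributed to the caller of the function that emits it, so the reported location points at user code.

import warnings

def forward():
    warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)

forward()  # reported at this call site rather than inside forward()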

177-194: Logic is correct; consider documenting return value.

The forward method correctly passes logits to FocalLoss and probabilities to AsymmetricFocalTverskyLoss. Per coding guidelines, docstrings should document return values.

         """
         Args:
             y_pred: (BNH[WD]) Logits (raw scores).
             y_true: (BNH[WD]) Ground truth labels.
+
+        Returns:
+            torch.Tensor: Weighted combination of focal loss and asymmetric focal Tversky loss.
         """
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 15fd428 and c27945a.

📒 Files selected for processing (1)
  • monai/losses/unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to the PEP8 style guide and are sensible and informative with regard to their function, while permitting simple names for loop and comprehension variables. Ensure routine names are meaningful with regard to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definitions and should describe each variable, return value, and raised exception in the appropriate section of a Google-style docstring. Examine code for logical errors or inconsistencies, and suggest what may be changed to address these. Suggest any enhancements that improve code efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/losses/unified_focal_loss.py
🧬 Code graph analysis (1)
monai/losses/unified_focal_loss.py (3)
monai/losses/focal_loss.py (1)
  • FocalLoss (26-202)
monai/networks/utils.py (1)
  • one_hot (170-220)
monai/utils/enums.py (1)
  • LossReduction (253-264)
🪛 Ruff (0.14.8)
monai/losses/unified_focal_loss.py

78-78: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


85-85: Avoid specifying long messages outside the exception class

(TRY003)


127-127: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (1)
monai/losses/unified_focal_loss.py (1)

36-44: Tests already cover the unified focal loss implementation.

New tests were added to cover the changes. The PR indicates that test coverage has been implemented, so this concern can be closed.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (2)
monai/losses/unified_focal_loss.py (2)

157-157: Replace Chinese comment with English.

-        self.use_softmax = use_softmax  # 儲存參數
+        self.use_softmax = use_softmax

112-114: Numerical instability when dice approaches 1.0.

When dice_class[:, i] equals 1.0, torch.pow(0, -self.gamma) produces infinity, causing NaN gradients.

Proposed fix
-                # Foreground classes: apply focal modulation
-                # Original logic: (1 - dice) * (1 - dice)^(-gamma) -> (1-dice)^(1-gamma)
-                loss_list.append((1 - dice_class[:, i]) * torch.pow(1 - dice_class[:, i], -self.gamma))
+                # Foreground classes: apply focal modulation
+                back_dice = torch.clamp(1 - dice_class[:, i], min=self.epsilon)
+                loss_list.append(back_dice * torch.pow(back_dice, -self.gamma))
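
A standalone numeric illustration of the instability and the clamped fix (assumed gamma and Dice values, not project code):

import torch

gamma = 0.75
dice = torch.tensor([1.0, 0.9])                   # perfect and near-perfect per-class Dice
print((1 - dice) * torch.pow(1 - dice, -gamma))   # first entry: 0 * inf -> nan

back_dice = torch.clamp(1 - dice, min=1e-7)
print(back_dice * torch.pow(back_dice, -gamma))   # finite for both entries
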
🧹 Nitpick comments (1)
monai/losses/unified_focal_loss.py (1)

74-80: Background exclusion logic is correct.

The implementation properly handles the single-channel edge case.

Optional: Add stacklevel to warning
-                warnings.warn("single channel prediction, `include_background=False` ignored.")
+                warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between c27945a and 1fba9d3.

📒 Files selected for processing (1)
  • monai/losses/unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to the PEP8 style guide and are sensible and informative with regard to their function, while permitting simple names for loop and comprehension variables. Ensure routine names are meaningful with regard to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definitions and should describe each variable, return value, and raised exception in the appropriate section of a Google-style docstring. Examine code for logical errors or inconsistencies, and suggest what may be changed to address these. Suggest any enhancements that improve code efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/losses/unified_focal_loss.py
🧬 Code graph analysis (1)
monai/losses/unified_focal_loss.py (3)
monai/losses/focal_loss.py (1)
  • FocalLoss (26-202)
monai/networks/utils.py (1)
  • one_hot (170-220)
monai/utils/enums.py (1)
  • LossReduction (253-264)
🪛 Ruff (0.14.8)
monai/losses/unified_focal_loss.py

76-76: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


83-83: Avoid specifying long messages outside the exception class

(TRY003)


125-125: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (3)
monai/losses/unified_focal_loss.py (3)

34-58: Constructor changes look good.

The include_background parameter addition aligns with MONAI conventions and is properly documented.


159-173: Component initialization is appropriate.

Both loss components are properly configured with consistent parameters.


181-188: Remove redundant to_onehot_y conversion concern.

Both focal_loss and asy_focal_tversky_loss independently apply to_onehot_y conversion when enabled. Since PyTorch's one-hot encoding creates new tensors (not in-place modifications), passing the same y_true reference to both components is safe and functionally correct. No conversion conflicts occur.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (3)
monai/losses/unified_focal_loss.py (2)

76-76: Add stacklevel=2 to warning.

Missing explicit stacklevel for proper traceback.

-                warnings.warn("single channel prediction, `include_background=False` ignored.")
+                warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)

Based on static analysis hints.


176-193: Document that y_pred must be logits.

The forward method expects raw logits since FocalLoss applies activation internally, and manual softmax/sigmoid is applied for the Tversky component. The docstring states this (line 179) but could be more explicit about the consequences of passing probabilities instead.

Consider adding a note in the docstring:

     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         """
         Args:
-            y_pred: (BNH[WD]) Logits (raw scores).
+            y_pred: (BNH[WD]) Logits (raw scores, not probabilities).
+                Do not pass pre-activated inputs; activation is applied internally.
             y_true: (BNH[WD]) Ground truth labels.
         """
tests/losses/test_unified_focal_loss.py (1)

26-61: Add test coverage for edge cases.

Current tests only cover perfect predictions with zero loss. Missing coverage for:

  • Imperfect predictions (non-zero loss)
  • include_background=False scenarios
  • to_onehot_y=True with integer labels
  • Multi-class softmax with imperfect predictions
Suggested additional test cases
# Case 2: Binary with include_background=False
[
    {
        "use_softmax": False,
        "include_background": False,
    },
    {
        "y_pred": torch.tensor([[[[logit_pos, logit_neg], [logit_neg, logit_pos]]]]),
        "y_true": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]]),
    },
    0.0,  # Should still be zero for perfect prediction
],
# Case 3: Multi-class with to_onehot_y=True (integer labels)
[
    {
        "use_softmax": True,
        "include_background": True,
        "to_onehot_y": True,
    },
    {
        "y_pred": torch.tensor([[[[logit_pos, logit_neg], [logit_neg, logit_neg], [logit_neg, logit_pos]]]]),
        "y_true": torch.tensor([[[[0, 2]]]]),  # Integer labels: class 0, class 2
    },
    0.0,
],

Do you want me to generate a complete test case addition or open an issue to track this?

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 1fba9d3 and 39664ea.

📒 Files selected for processing (2)
  • monai/losses/unified_focal_loss.py
  • tests/losses/test_unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to the PEP8 style guide and are sensible and informative with regard to their function, while permitting simple names for loop and comprehension variables. Ensure routine names are meaningful with regard to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definitions and should describe each variable, return value, and raised exception in the appropriate section of a Google-style docstring. Examine code for logical errors or inconsistencies, and suggest what may be changed to address these. Suggest any enhancements that improve code efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • tests/losses/test_unified_focal_loss.py
  • monai/losses/unified_focal_loss.py
🧬 Code graph analysis (2)
tests/losses/test_unified_focal_loss.py (1)
monai/losses/unified_focal_loss.py (1)
  • AsymmetricUnifiedFocalLoss (129-193)
monai/losses/unified_focal_loss.py (3)
monai/losses/focal_loss.py (1)
  • FocalLoss (26-202)
monai/networks/utils.py (1)
  • one_hot (170-220)
monai/utils/enums.py (1)
  • LossReduction (253-264)
🪛 Ruff (0.14.8)
monai/losses/unified_focal_loss.py

76-76: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


83-83: Avoid specifying long messages outside the exception class

(TRY003)


126-126: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (5)
monai/losses/unified_focal_loss.py (3)

184-187: LGTM: Correct activation choice for Tversky loss.

FocalLoss handles its own activation internally, so this manual conversion to probabilities for AsymmetricFocalTverskyLoss is correct. The activation choice (softmax vs sigmoid) properly follows the use_softmax flag.


89-116: Implementation correctly handles include_background with standard MONAI slicing pattern.

When include_background=False, channel index 0 is excluded from the calculation—the code does this via tensor slicing at lines 79-80 before the asymmetry loop. Once sliced, all remaining channels receive focal modulation; none are treated as background. The loss only supports binary segmentation, so asymmetry designates the first present channel as background and all others as foreground, which is the intended behavior per the documented design comment (lines 101-104).
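
A minimal illustration of that slicing pattern (hypothetical shapes, not the module's code):

import torch

y_pred = torch.randn(2, 3, 4, 4)                    # (B, C, H, W) with channel 0 as background
y_true = torch.randint(0, 2, (2, 3, 4, 4)).float()

# drop channel 0 before the per-class asymmetry loop when include_background=False
y_pred_fg = y_pred[:, 1:]
y_true_fg = y_true[:, 1:]
print(y_pred_fg.shape, y_true_fg.shape)             # both torch.Size([2, 2, 4, 4])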


160-174: Both composed losses independently transform y_true with their respective settings. Each applies its own non-destructive transformations (one-hot encoding creates new tensors; slicing creates new views), so no actual collision occurs. This is correct by design—composed losses should handle their own input transformations.

tests/losses/test_unified_focal_loss.py (2)

22-24: LGTM: High-confidence logits ensure clear test expectations.

Using ±10.0 logits produces near-perfect probabilities (~0.9999 and ~0.0001), making zero-loss expectations reasonable for perfect predictions.
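
A quick check of the stated magnitudes:

import torch
print(torch.sigmoid(torch.tensor([10.0, -10.0])))   # approx. [0.99995, 0.0000454]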


77-89: LGTM: CUDA test correctly instantiates loss.

The test properly moves the loss module to CUDA (line 85), ensuring both model parameters and inputs are on the same device.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
monai/losses/unified_focal_loss.py (1)

74-80: Background exclusion correctly implemented.

The logic properly removes the first channel when include_background=False, consistent with FocalLoss. The single-channel warning is appropriate.

Optional: Add stacklevel to warning for better traceability
-                warnings.warn("single channel prediction, `include_background=False` ignored.")
+                warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 39664ea and 41dccad.

📒 Files selected for processing (2)
  • monai/losses/unified_focal_loss.py
  • tests/losses/test_unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to the PEP8 style guide and are sensible and informative with regard to their function, while permitting simple names for loop and comprehension variables. Ensure routine names are meaningful with regard to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definitions and should describe each variable, return value, and raised exception in the appropriate section of a Google-style docstring. Examine code for logical errors or inconsistencies, and suggest what may be changed to address these. Suggest any enhancements that improve code efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/losses/unified_focal_loss.py
  • tests/losses/test_unified_focal_loss.py
🧬 Code graph analysis (2)
monai/losses/unified_focal_loss.py (3)
monai/losses/focal_loss.py (1)
  • FocalLoss (26-202)
monai/networks/utils.py (1)
  • one_hot (170-220)
monai/utils/enums.py (1)
  • LossReduction (253-264)
tests/losses/test_unified_focal_loss.py (1)
monai/losses/unified_focal_loss.py (1)
  • AsymmetricUnifiedFocalLoss (129-193)
🪛 Ruff (0.14.8)
monai/losses/unified_focal_loss.py

76-76: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


83-83: Avoid specifying long messages outside the exception class

(TRY003)


126-126: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (8)
tests/losses/test_unified_focal_loss.py (3)

22-42: Binary test case is correct.

The high-confidence logits (±10.0) correctly produce near-perfect probabilities after sigmoid. The alignment between predictions and targets should yield near-zero loss.


66-70: Test structure is correct.

Parameterized test properly unpacks configuration and data, with appropriate numerical tolerances for floating-point comparison.


77-89: CUDA test correctly adapted to logits interface.

The test properly uses logits with use_softmax=False for binary segmentation and correctly moves both tensors and the loss module to CUDA.

monai/losses/unified_focal_loss.py (5)

19-19: Import is correct.

FocalLoss is properly imported from monai.losses for reuse in the unified loss.


34-58: Constructor properly extended with include_background.

The parameter is correctly documented, defaulted, and stored for use in the forward method, consistent with MONAI's loss interface patterns.


97-117: Asymmetric focal modulation correctly implemented.

Background class uses standard Dice loss while foreground classes apply focal modulation (1-dice)^(1-gamma). Clamping prevents numerical instability when dice approaches 1.0.
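
Written out per class (using the comment's notation, where dice_c is the class-wise score and gamma the focal parameter):

L_{\text{background}} = 1 - \text{dice}_b, \qquad
L_{\text{foreground}} = \left(1 - \text{dice}_f\right)^{\,1-\gamma}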


135-174: Composition pattern correctly implemented.

The constructor properly instantiates and configures both FocalLoss and AsymmetricFocalTverskyLoss components with shared parameters, enabling modular loss computation.


176-193: Forward method correctly combines losses.

FocalLoss operates on logits (with internal activation), while AsymmetricFocalTverskyLoss requires probabilities. The explicit softmax/sigmoid conversion for the Tversky component is correct, and the weighted combination is straightforward.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
tests/losses/test_unified_focal_loss.py (1)

26-56: Add test case for include_background=False.

Test cases cover sigmoid/softmax modes correctly, but the include_background parameter (added per PR objectives) is only tested with True. Add a multi-class case with include_background=False to validate background channel exclusion.

Example test case
# Case 2: Multi-class with background excluded
[
    {
        "use_softmax": True,
        "include_background": False,
    },
    {
        "y_pred": torch.tensor([[[[logit_pos, logit_neg]], [[logit_neg, logit_pos]], [[logit_neg, logit_neg]]]]),
        "y_true": torch.tensor([[[[1.0, 0.0]], [[0.0, 1.0]], [[0.0, 0.0]]]]),
    },
    0.0,
],
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 41dccad and ca81e4a.

📒 Files selected for processing (1)
  • tests/losses/test_unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to the PEP8 style guide and are sensible and informative with regard to their function, while permitting simple names for loop and comprehension variables. Ensure routine names are meaningful with regard to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definitions and should describe each variable, return value, and raised exception in the appropriate section of a Google-style docstring. Examine code for logical errors or inconsistencies, and suggest what may be changed to address these. Suggest any enhancements that improve code efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • tests/losses/test_unified_focal_loss.py
🧬 Code graph analysis (1)
tests/losses/test_unified_focal_loss.py (1)
monai/losses/unified_focal_loss.py (1)
  • AsymmetricUnifiedFocalLoss (129-193)
🔇 Additional comments (3)
tests/losses/test_unified_focal_loss.py (3)

22-24: LGTM - Clear test constants.

Helper logits are well-defined for creating high-confidence predictions.


62-65: LGTM - Parameterized test structure correct.

Test method properly unpacks config and data dicts.


72-84: LGTM - CUDA test properly implemented.

Test correctly uses logits and moves both tensors and loss module to CUDA.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/losses/test_unified_focal_loss.py (1)

79-82: Add docstring.

Per coding guidelines, add a docstring describing that this test validates shape mismatch error handling.

🧹 Nitpick comments (1)
tests/losses/test_unified_focal_loss.py (1)

26-68: Add at least one test with non-zero loss.

All test cases expect 0.0 loss with perfect predictions. Add a case with imperfect predictions (e.g., logits slightly off from ground truth) to verify the loss is actually computed, not just validating tensor shape compatibility.

Optional: Expand parameter coverage

Consider adding test cases that vary:

  • to_onehot_y=True with class-index format ground truth
  • weight, delta, gamma to non-default values
  • reduction modes (SUM, NONE)

These are optional and can be deferred.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between ca81e4a and c9002e0.

📒 Files selected for processing (1)
  • tests/losses/test_unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to the PEP8 style guide and are sensible and informative with regard to their function, while permitting simple names for loop and comprehension variables. Ensure routine names are meaningful with regard to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definitions and should describe each variable, return value, and raised exception in the appropriate section of a Google-style docstring. Examine code for logical errors or inconsistencies, and suggest what may be changed to address these. Suggest any enhancements that improve code efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • tests/losses/test_unified_focal_loss.py
🧬 Code graph analysis (1)
tests/losses/test_unified_focal_loss.py (1)
monai/losses/unified_focal_loss.py (1)
  • AsymmetricUnifiedFocalLoss (129-193)
🔇 Additional comments (1)
tests/losses/test_unified_focal_loss.py (1)

22-24: LGTM—High-confidence logits for perfect-prediction tests.

Values create predictions very close to 0 or 1, suitable for validating near-zero loss on ideal inputs.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
tests/losses/test_unified_focal_loss.py (1)

26-68: Test coverage is adequate for the new interface.

The three cases properly exercise binary sigmoid, multi-class softmax, and background exclusion. All use perfect predictions (loss=0.0), which validates the happy path.

Consider adding at least one test case with imperfect predictions (non-zero expected loss) to verify the actual loss computation, not just the zero-loss edge case. This would increase confidence in the refactored implementation.

monai/losses/unified_focal_loss.py (1)

134-173: Excellent refactoring using composition.

Replacing inline logic with composed FocalLoss and AsymmetricFocalTverskyLoss improves maintainability and reusability. Parameters are correctly forwarded to both components.

Per past review feedback, the use_softmax docstring (lines 151-152) could add brief guidance: softmax for mutually exclusive classes (standard multi-class), sigmoid for multi-label/overlapping classes.
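
A small standalone contrast of the two activations (assumed values) that such guidance would describe:

import torch

logits = torch.tensor([[2.0, 0.5, -1.0]])
print(torch.softmax(logits, dim=1).sum(dim=1))  # 1.0: channels compete (mutually exclusive classes)
print(torch.sigmoid(logits))                    # independent per-channel probabilities (multi-label)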

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between c9002e0 and 2f4657e.

📒 Files selected for processing (2)
  • monai/losses/unified_focal_loss.py
  • tests/losses/test_unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to the PEP8 style guide and are sensible and informative with regard to their function, while permitting simple names for loop and comprehension variables. Ensure routine names are meaningful with regard to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definitions and should describe each variable, return value, and raised exception in the appropriate section of a Google-style docstring. Examine code for logical errors or inconsistencies, and suggest what may be changed to address these. Suggest any enhancements that improve code efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • tests/losses/test_unified_focal_loss.py
  • monai/losses/unified_focal_loss.py
🧬 Code graph analysis (2)
tests/losses/test_unified_focal_loss.py (1)
monai/losses/unified_focal_loss.py (1)
  • AsymmetricUnifiedFocalLoss (128-195)
monai/losses/unified_focal_loss.py (3)
monai/losses/focal_loss.py (1)
  • FocalLoss (26-202)
monai/networks/utils.py (1)
  • one_hot (170-220)
monai/utils/enums.py (1)
  • LossReduction (253-264)
🪛 Ruff (0.14.8)
monai/losses/unified_focal_loss.py

83-83: Avoid specifying long messages outside the exception class

(TRY003)


125-125: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (9)
tests/losses/test_unified_focal_loss.py (3)

22-24: LGTM.

Helper constants are well-commented and appropriate for generating high-confidence predictions in tests.


74-85: LGTM.

Docstring properly documents parameters. Parameterized test structure is clean and flexible.


92-104: LGTM.

CUDA test correctly uses the new API with sigmoid activation and validates GPU compatibility.

monai/losses/unified_focal_loss.py (6)

19-19: LGTM.

Import required for the new composition-based implementation.


34-58: LGTM.

Adding include_background parameter aligns with MONAI loss function conventions and enables proper multi-class segmentation support.


74-80: LGTM.

Background exclusion logic correctly follows the FocalLoss pattern, including the single-channel warning.


106-114: Asymmetry logic is correct.

Background channel (index 0 when include_background=True) uses standard Dice loss, while foreground channels use focal modulation. When include_background=False, all channels receive focal modulation since background was removed. Clamping prevents numerical instability.


118-125: LGTM.

Reduction logic correctly handles MEAN, SUM, and NONE cases with appropriate error for unsupported values.


175-195: LGTM.

Forward pass correctly handles different input requirements: logits for FocalLoss (which applies activation internally), probabilities for AsymmetricFocalTverskyLoss. The weighted combination is straightforward and matches the documented formula.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/losses/test_unified_focal_loss.py (1)

117-128: Add docstring for CUDA test.

Per coding guidelines and past review comment, add a docstring describing this test's purpose: validating CUDA compatibility with perfect predictions.

Suggested docstring
     def test_with_cuda(self):
+        """Test AsymmetricUnifiedFocalLoss CUDA compatibility with perfect predictions."""
         loss = AsymmetricUnifiedFocalLoss()
🧹 Nitpick comments (3)
tests/losses/test_unified_focal_loss.py (1)

25-93: Suggest adding imperfect prediction test cases.

All three cases test perfect predictions (loss=0.0). Add at least one case with misaligned logits/labels to verify the loss computes non-zero values correctly and gradients flow properly.

Example imperfect case
[  # Case 3: Imperfect prediction
    {"use_softmax": False, "include_background": True},
    {
        "y_pred": torch.tensor([[[[0.0, -2.0], [2.0, 0.0]]]]),  # Moderate confidence
        "y_true": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]]),
    },
    # Expected: positive loss value (compute actual expected value)
],
monai/losses/unified_focal_loss.py (2)

60-86: LGTM: Background exclusion logic is correct.

The include_background handling properly slices channel 0 from both tensors and warns on single-channel edge cases. Shape validation and clipping are correctly placed.

Note: Static analysis flags line 83 for a long exception message (TRY003). Consider a custom exception class if this pattern recurs, but current usage is acceptable.
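
If the pattern does recur, a hypothetical custom exception along the lines Ruff TRY003 suggests (class name and message are illustrative only, not part of this PR):

class ShapeMismatchError(ValueError):
    """Illustrative exception that owns its message, per Ruff TRY003."""

    def __init__(self, expected: tuple, actual: tuple) -> None:
        super().__init__(f"ground truth has different shape ({expected}) from input ({actual})")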


88-125: LGTM: Asymmetric focal Tversky logic is sound.

The per-class loss correctly applies standard Tversky to background (when included) and focal-modulated Tversky to foreground. Clamping prevents numerical instability. Reduction handling is complete.

Static analysis flags line 125 for a long exception message (TRY003). Consider extracting to a constant or custom exception if this pattern is reused.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 2f4657e and e63e36e.

📒 Files selected for processing (2)
  • monai/losses/unified_focal_loss.py
  • tests/losses/test_unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to the PEP8 style guide and are sensible and informative with regard to their function, while permitting simple names for loop and comprehension variables. Ensure routine names are meaningful with regard to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definitions and should describe each variable, return value, and raised exception in the appropriate section of a Google-style docstring. Examine code for logical errors or inconsistencies, and suggest what may be changed to address these. Suggest any enhancements that improve code efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • tests/losses/test_unified_focal_loss.py
  • monai/losses/unified_focal_loss.py
🧬 Code graph analysis (2)
tests/losses/test_unified_focal_loss.py (1)
monai/losses/unified_focal_loss.py (1)
  • AsymmetricUnifiedFocalLoss (128-197)
monai/losses/unified_focal_loss.py (3)
monai/losses/focal_loss.py (1)
  • FocalLoss (26-202)
monai/networks/utils.py (1)
  • one_hot (170-220)
monai/utils/enums.py (1)
  • LossReduction (253-264)
🪛 Ruff (0.14.8)
monai/losses/unified_focal_loss.py

83-83: Avoid specifying long messages outside the exception class

(TRY003)


125-125: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (6)
tests/losses/test_unified_focal_loss.py (2)

22-23: LGTM: Clear test constants.

Module-level logit constants are well-named and appropriate for testing high-confidence predictions.


99-110: LGTM: Well-documented parameterized test.

Docstring and implementation are clear. Tolerance levels are appropriate.

monai/losses/unified_focal_loss.py (4)

19-19: LGTM: Import supports composition pattern.

FocalLoss import enables the refactored AsymmetricUnifiedFocalLoss to reuse existing focal loss implementation.


34-58: LGTM: Consistent API with include_background parameter.

Addition of include_background aligns with MONAI loss conventions. Default True preserves backward compatibility.


128-175: LGTM: Clean composition of focal components.

Refactoring to compose FocalLoss and AsymmetricFocalTverskyLoss eliminates code duplication and ensures consistent parameter handling. Docstrings clearly distinguish sigmoid vs. softmax use cases.


177-197: LGTM: Forward pass correctly combines loss components.

The focal loss operates on logits while the Tversky component operates on probabilities after explicit activation. Each component independently handles to_onehot_y and include_background, ensuring correct behavior. Weighted combination is straightforward.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
monai/losses/unified_focal_loss.py (1)

157-157: Remove Chinese comment.

-        self.use_softmax = use_softmax  # 儲存參數
+        self.use_softmax = use_softmax

This was previously flagged but remains in the code.

🧹 Nitpick comments (1)
monai/losses/unified_focal_loss.py (1)

118-125: Reduction logic is correct.

Standard reduction pattern implemented properly.

For consistency with MONAI style, consider extracting the long error message to a constant or shortening it (static analysis hint TRY003):

-        raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
+        raise ValueError(f"Unsupported reduction: {self.reduction}")
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between e63e36e and 1d196dc.

📒 Files selected for processing (2)
  • monai/losses/unified_focal_loss.py
  • tests/losses/test_unified_focal_loss.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/losses/test_unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to the PEP8 style guide and are sensible and informative with regard to their function, while permitting simple names for loop and comprehension variables. Ensure routine names are meaningful with regard to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definitions and should describe each variable, return value, and raised exception in the appropriate section of a Google-style docstring. Examine code for logical errors or inconsistencies, and suggest what may be changed to address these. Suggest any enhancements that improve code efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/losses/unified_focal_loss.py
🪛 Ruff (0.14.8)
monai/losses/unified_focal_loss.py

76-76: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


83-83: Avoid specifying long messages outside the exception class

(TRY003)


125-125: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (5)
monai/losses/unified_focal_loss.py (5)

19-19: LGTM.

Import required for the new composition-based implementation.


34-59: LGTM.

The include_background parameter is properly documented and maintains backward compatibility with True as default.


74-81: Background exclusion logic is correct.

The warning for single-channel predictions and slicing logic are appropriate.

However, add stacklevel=2 to the warning at line 76 for proper caller identification:

-                warnings.warn("single channel prediction, `include_background=False` ignored.")
+                warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)

Based on static analysis hints.


159-173: Asymmetric gamma application is intentional and correct.

The Unified Focal Loss design intentionally exploits gamma asymmetry to enable simultaneous suppression and enhancement effects in its component losses. In FocalLoss, gamma down-weights easy-to-classify pixels, while in Focal Tversky Loss, gamma enhances rather than suppresses easy examples. Gamma controls weights for difficult-to-predict samples; distribution-based corrections apply sample-by-sample while region-based corrections apply class-by-class during macro-averaging. This composition pattern correctly implements the unified focal loss framework.
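
Written out, the two roles of gamma contrasted above (p_t is the predicted probability of the true class, TI the Tversky index):

\text{focal loss term: } -(1 - p_t)^{\gamma} \log p_t \quad \text{(down-weights easy pixels)}
\qquad
\text{focal Tversky term: } (1 - TI)^{\,1-\gamma} \quad \text{(enhances easy classes when } 0 < \gamma < 1\text{)}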


175-192: Forward implementation is correct.

The loss properly:

  1. Computes focal loss on logits
  2. Converts logits to probabilities for Tversky component via softmax or sigmoid
  3. Combines losses with configurable weighting

Test coverage includes both sigmoid and softmax activation paths with appropriate input dimensions.
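
A condensed sketch of that flow (illustrative only, not the actual MONAI source; the orientation of the weight factor follows the walkthrough's weighted-sum description):

import torch

def unified_forward(asy_focal_loss, asy_focal_tversky_loss, y_pred, y_true,
                    weight=0.5, use_softmax=False):
    # 1. focal component consumes raw logits (activation is applied internally)
    focal = asy_focal_loss(y_pred, y_true)
    # 2. Tversky component consumes probabilities
    probs = torch.softmax(y_pred, dim=1) if use_softmax else torch.sigmoid(y_pred)
    tversky = asy_focal_tversky_loss(probs, y_true)
    # 3. configurable weighted combination of the two terms
    return weight * focal + (1 - weight) * tversky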

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

♻️ Duplicate comments (1)
tests/losses/test_unified_focal_loss.py (1)

93-106: Add docstring.

Per coding guidelines, add a docstring describing the test purpose.

🔎 Proposed fix
     def test_with_cuda(self):
+        """Validate CUDA compatibility of AsymmetricUnifiedFocalLoss."""
         if not torch.cuda.is_available():

Based on coding guidelines.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 1d196dc and f7cad77.

📒 Files selected for processing (2)
  • monai/losses/unified_focal_loss.py
  • tests/losses/test_unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to the PEP8 style guide and are sensible and informative with regard to their function, while permitting simple names for loop and comprehension variables. Ensure routine names are meaningful with regard to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definitions and should describe each variable, return value, and raised exception in the appropriate section of a Google-style docstring. Examine code for logical errors or inconsistencies, and suggest what may be changed to address these. Suggest any enhancements that improve code efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/losses/unified_focal_loss.py
  • tests/losses/test_unified_focal_loss.py
🧬 Code graph analysis (2)
monai/losses/unified_focal_loss.py (1)
monai/utils/enums.py (1)
  • LossReduction (253-264)
tests/losses/test_unified_focal_loss.py (3)
monai/losses/unified_focal_loss.py (1)
  • AsymmetricUnifiedFocalLoss (229-298)
tests/test_utils.py (1)
  • assert_allclose (119-159)
monai/networks/nets/quicknat.py (1)
  • is_cuda (433-437)
🪛 Ruff (0.14.10)
monai/losses/unified_focal_loss.py

83-83: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


89-89: Avoid specifying long messages outside the exception class

(TRY003)


188-188: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


194-194: Avoid specifying long messages outside the exception class

(TRY003)


291-291: Avoid specifying long messages outside the exception class

(TRY003)

tests/losses/test_unified_focal_loss.py

83-83: Unused method argument: expected_val

(ARG002)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: packaging
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: build-docs
🔇 Additional comments (7)
tests/losses/test_unified_focal_loss.py (2)

21-66: Test case definitions are well-structured.

The test cases cover binary logits, 2-channel binary, and multi-class scenarios with appropriate shapes and parameter combinations. The use of 10.0/-10.0 logits ensures near-perfect probabilities for validation.


71-80: LGTM.

Tolerance of 1e-3 is appropriate given that logits of ±10.0 don't yield exact probabilities of 0.0/1.0.

monai/losses/unified_focal_loss.py (5)

34-60: LGTM.

The use_softmax parameter is properly integrated with clear documentation.


91-129: Loss calculations are correct.

The background dice and foreground focal-tversky computations align with the paper's formulation. The use of 1/gamma exponent for foreground classes properly implements the focal modulation.


196-226: Focal loss implementation is correct.

The asymmetric weighting (background focal, foreground standard CE) with delta balancing correctly addresses class imbalance.


237-279: Composition pattern is well-executed.

Creating internal loss instances with shared parameters ensures consistency and avoids duplication in the forward pass.


281-298: Forward logic is sound.

The shape validation correctly handles edge cases (binary logits, to_onehot_y), and the weighted combination properly unifies focal and tversky losses.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
monai/losses/unified_focal_loss.py (1)

240-240: num_classes parameter is unused.

The num_classes parameter is stored at line 260 but never referenced. Either remove it or use it.

🔎 Proposed fix to remove unused parameter
     def __init__(
         self,
         to_onehot_y: bool = False,
-        num_classes: int = 2,
         weight: float = 0.5,
         gamma: float = 0.5,
         delta: float = 0.7,
         reduction: LossReduction | str = LossReduction.MEAN,
         use_softmax: bool = False,
     ):
         """
         Args:
             to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
-            num_classes: number of classes. Defaults to 2.
             weight: weight factor to balance between Focal Loss and Tversky Loss.

And remove self.num_classes = num_classes at line 260.

Also applies to: 260-260

♻️ Duplicate comments (2)
monai/losses/unified_focal_loss.py (2)

83-83: Add stacklevel=2 to warning.

Per static analysis, add stacklevel=2 so the warning points to the caller.

-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)

188-188: Add stacklevel=2 to warning.

Per static analysis, add stacklevel=2.

-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
🧹 Nitpick comments (6)
monai/losses/unified_focal_loss.py (6)

122-129: Unreachable fallback return.

The final return torch.mean(all_losses) at line 129 is unreachable for valid LossReduction values. Consider raising an error for invalid reductions or removing the redundant return.

🔎 Proposed fix
         if self.reduction == LossReduction.MEAN.value:
             return torch.mean(all_losses)
         if self.reduction == LossReduction.SUM.value:
             return torch.sum(all_losses)
         if self.reduction == LossReduction.NONE.value:
             return all_losses
-
-        return torch.mean(all_losses)
+        raise ValueError(f"Unsupported reduction: {self.reduction}")

158-159: Incomplete docstring for reduction parameter.

The reduction parameter docstring is missing its description.

-            reduction: {``"none"``, ``"mean"``, ``"sum"``}
-            use_softmax: whether to use softmax to transform logits. Defaults to False.
+            reduction: {``"none"``, ``"mean"``, ``"sum"``}
+                Specifies the reduction to apply to the output. Defaults to ``"mean"``.
+            use_softmax: whether to use softmax to transform logits. Defaults to False.

175-202: Consider extracting shared preprocessing logic.

Lines 175-202 duplicate the preprocessing from AsymmetricFocalTverskyLoss (lines 70-99): single-channel handling, one-hot conversion, shape validation, and probability conversion. Extract to a shared helper to reduce duplication.


293-294: Sub-losses preprocess inputs independently, causing duplicate work.

Both self.asy_focal_loss and self.asy_focal_tversky_loss independently apply sigmoid/softmax, one-hot encoding, and clamping to the same inputs. For performance, consider preprocessing once in this forward method and passing processed tensors to sub-losses configured to skip preprocessing.


62-68: Docstring should document return value and exceptions.

Per coding guidelines, docstrings should describe return value and raised exceptions.

         """
         Args:
             y_pred: prediction logits or probabilities. Shape should be (B, C, spatial_dims).
             y_true: ground truth labels. Shape should match y_pred.
+
+        Returns:
+            torch.Tensor: Computed loss. Shape depends on reduction setting.
+
+        Raises:
+            ValueError: If y_true and y_pred shapes don't match after preprocessing.
         """

168-173: Docstring should document return value and exceptions.

Same as AsymmetricFocalTverskyLoss.forward - add Returns and Raises sections.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between f7cad77 and 4166faa.

📒 Files selected for processing (1)
  • monai/losses/unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to the PEP8 style guide and are sensible and informative with regard to their function, while permitting simple names for loop and comprehension variables. Ensure routine names are meaningful with regard to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definitions and should describe each variable, return value, and raised exception in the appropriate section of a Google-style docstring. Examine code for logical errors or inconsistencies, and suggest what may be changed to address these. Suggest any enhancements that improve code efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/losses/unified_focal_loss.py
🧬 Code graph analysis (1)
monai/losses/unified_focal_loss.py (2)
monai/utils/enums.py (1)
  • LossReduction (253-264)
monai/networks/utils.py (1)
  • one_hot (170-220)
🪛 Ruff (0.14.10)
monai/losses/unified_focal_loss.py

83-83: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


89-89: Avoid specifying long messages outside the exception class

(TRY003)


188-188: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


194-194: Avoid specifying long messages outside the exception class

(TRY003)


291-291: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (2)
monai/losses/unified_focal_loss.py (2)

288-291: Shape validation is correct but duplicated.

The shape check here is a reasonable early guard, though sub-losses will validate again after their own preprocessing. Acceptable as-is for fail-fast behavior.


296-298: LGTM!

Clean delegation to sub-losses with proper weighted combination.

Signed-off-by: ytl0623 <david89062388@gmail.com>

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
monai/losses/unified_focal_loss.py (1)

237-279: Remove unused num_classes parameter or document for backward compatibility.

The num_classes parameter is stored at line 260 but never used. It's not passed to internal loss instances and not referenced in the forward method. Either remove it or add a comment explaining why it's retained.

♻️ Duplicate comments (4)
monai/losses/unified_focal_loss.py (2)

186-191: Add stacklevel=2 to warning.

The warnings.warn call should specify stacklevel=2 so the warning points to the caller's code.

🔎 Proposed fix
         if self.to_onehot_y:
             if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
             else:

Based on static analysis hint.


81-86: Add stacklevel=2 to warning.

The warnings.warn call should specify stacklevel=2 so the warning points to the caller's code.

🔎 Proposed fix
         if self.to_onehot_y:
             if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
             else:

Based on static analysis hint.

tests/losses/test_unified_focal_loss.py (2)

94-107: Add docstring.

Per coding guidelines, add a docstring describing that this test validates CUDA compatibility of the loss.

🔎 Proposed fix
     def test_with_cuda(self):
+        """Verify CUDA compatibility by running loss on GPU tensors when available."""
         if not torch.cuda.is_available():

83-87: Remove unused parameter and add docstring.

The expected_val parameter is unused. Remove it from the signature and update the test case accordingly. Also add a docstring per coding guidelines.

🔎 Proposed fix
     @parameterized.expand([TEST_CASE_MULTICLASS_WRONG])
-    def test_wrong_prediction(self, input_data, expected_val, args):
+    def test_wrong_prediction(self, input_data, args):
+        """Verify that wrong predictions yield high loss values."""
         loss_func = AsymmetricUnifiedFocalLoss(**args)
         result = loss_func(**input_data)
         self.assertGreater(result.item(), 1.0, "Loss should be high for wrong predictions")

Update TEST_CASE_MULTICLASS_WRONG at line 62 to remove the None value:

 TEST_CASE_MULTICLASS_WRONG = [
     {
         "y_pred": torch.tensor(
             [[[[-10.0, -10.0], [-10.0, -10.0]], [[10.0, 10.0], [10.0, 10.0]], [[-10.0, -10.0], [-10.0, -10.0]]]]
         ),
         "y_true": torch.tensor([[[[0, 0], [0, 0]]]]),  # GT is class 0, but Pred is class 1
     },
-    None,
     {"use_softmax": True, "to_onehot_y": True},
 ]

Based on static analysis hint.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 4166faa and 45d9877.

📒 Files selected for processing (2)
  • monai/losses/unified_focal_loss.py
  • tests/losses/test_unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • tests/losses/test_unified_focal_loss.py
  • monai/losses/unified_focal_loss.py
🧬 Code graph analysis (2)
tests/losses/test_unified_focal_loss.py (1)
monai/losses/unified_focal_loss.py (1)
  • AsymmetricUnifiedFocalLoss (229-298)
monai/losses/unified_focal_loss.py (2)
monai/utils/enums.py (1)
  • LossReduction (253-264)
monai/networks/utils.py (1)
  • one_hot (170-220)
🪛 Ruff (0.14.10)
tests/losses/test_unified_focal_loss.py

84-84: Unused method argument: expected_val

(ARG002)

monai/losses/unified_focal_loss.py

83-83: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


89-89: Avoid specifying long messages outside the exception class

(TRY003)


188-188: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


194-194: Avoid specifying long messages outside the exception class

(TRY003)


291-291: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: build-docs
  • GitHub Check: packaging
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: min-dep-py3 (3.12)

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (2)
monai/losses/unified_focal_loss.py (2)

81-86: Add stacklevel=2 to warning (still open from past reviews).

🔎 Proposed fix
         if self.to_onehot_y:
             if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
             else:

Based on static analysis hint.


186-191: Add stacklevel=2 to warning (still open from past reviews).

🔎 Proposed fix
         if self.to_onehot_y:
             if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
             else:

Based on static analysis hint.

🧹 Nitpick comments (3)
monai/losses/unified_focal_loss.py (3)

110-129: Loss computation is correct but consider documenting asymmetric treatment.

Background uses standard Dice loss while foreground uses focal modulation. This asymmetry is intentional to prioritize foreground classes, but could benefit from an inline comment for future maintainers.

🔎 Optional documentation enhancement
         # Calculate losses separately for each class
-        # Background: Standard Dice Loss
+        # Background: Standard Dice Loss (no focal modulation to preserve sensitivity)
         back_dice = 1 - dice_class[:, 0]
 
-        # Foreground: Focal Tversky Loss
+        # Foreground: Focal Tversky Loss (focal modulation to down-weight easy examples)
         fore_dice = torch.pow(1 - dice_class[:, 1:], 1 / self.gamma)

175-182: Optional: Extract duplicated single-channel handling to helper function.

The same single-channel auto-conversion logic appears in both AsymmetricFocalTverskyLoss (lines 69-78) and AsymmetricFocalLoss. Consider extracting to a shared helper if more losses adopt this pattern.
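
A minimal sketch of such a helper, assuming it would live alongside the two loss classes (the function name is hypothetical):

import torch

from monai.networks.utils import one_hot

def expand_binary_to_two_channels(y_pred: torch.Tensor, y_true: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Shared sketch of the single-channel auto-conversion: sigmoid the logits,
    # build a [background, foreground] probability map, and one-hot the target.
    y_pred = torch.sigmoid(y_pred)
    y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
    if y_true.shape[1] == 1:
        y_true = one_hot(y_true, num_classes=2)
    return y_pred, y_true

Both losses could then call it when y_pred has a single channel and use_softmax is False.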


266-279: Consider exposing separate gamma parameters for the two loss components.

AsymmetricFocalLoss defaults to gamma=2.0 while AsymmetricFocalTverskyLoss defaults to gamma=0.75, but AsymmetricUnifiedFocalLoss forces both to use the same gamma value. This prevents users from independently tuning focal modulation for distribution-based (CE) vs region-based (Dice) objectives.

Not blocking, but consider adding gamma_focal and gamma_tversky parameters in a future revision if users request finer control.
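
A rough sketch of that idea as a thin wrapper (the class and parameter names are hypothetical, and constructor arguments other than gamma/use_softmax/to_onehot_y are omitted for brevity):

import torch
from torch.nn.modules.loss import _Loss

from monai.losses.unified_focal_loss import AsymmetricFocalLoss, AsymmetricFocalTverskyLoss

class UnifiedFocalLossWithSeparateGamma(_Loss):
    # Illustrative only: lets the distribution-based and region-based terms
    # use independent focal exponents instead of a single shared gamma.
    def __init__(self, gamma_focal: float = 2.0, gamma_tversky: float = 0.75,
                 weight: float = 0.5, to_onehot_y: bool = False, use_softmax: bool = False):
        super().__init__()
        self.weight = weight
        self.focal = AsymmetricFocalLoss(gamma=gamma_focal, to_onehot_y=to_onehot_y, use_softmax=use_softmax)
        self.tversky = AsymmetricFocalTverskyLoss(gamma=gamma_tversky, to_onehot_y=to_onehot_y, use_softmax=use_softmax)

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        return self.weight * self.focal(y_pred, y_true) + (1 - self.weight) * self.tversky(y_pred, y_true)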

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 45d9877 and 05dac9e.

📒 Files selected for processing (1)
  • monai/losses/unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/losses/unified_focal_loss.py
🧬 Code graph analysis (1)
monai/losses/unified_focal_loss.py (2)
monai/utils/enums.py (1)
  • LossReduction (253-264)
monai/networks/utils.py (1)
  • one_hot (170-220)
🪛 Ruff (0.14.10)
monai/losses/unified_focal_loss.py

83-83: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


89-89: Avoid specifying long messages outside the exception class

(TRY003)


188-188: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


194-194: Avoid specifying long messages outside the exception class

(TRY003)


291-291: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (3)
monai/losses/unified_focal_loss.py (3)

69-78: Verify single-channel sigmoid conversion produces valid probability distribution.

The auto-handling converts single-channel logits to two-channel probabilities via torch.cat([1 - y_pred, y_pred], dim=1) after sigmoid. This assumes y_pred after sigmoid is the foreground probability. Confirm this matches user expectations and aligns with the rest of MONAI's binary segmentation conventions.
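
For reference, a quick sanity check of that construction (standalone sketch, not taken from the test suite):

import torch

logits = torch.randn(2, 1, 4, 4)                   # (B, 1, H, W) binary logits
prob_fg = torch.sigmoid(logits)                    # foreground probability
two_ch = torch.cat([1 - prob_fg, prob_fg], dim=1)  # (B, 2, H, W)
# The two channels form a valid probability distribution at every pixel.
assert torch.allclose(two_ch.sum(dim=1), torch.ones(2, 4, 4))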


91-99: LGTM - probability conversion logic is correct.

The is_already_prob flag prevents double conversion for single-channel inputs, and clamping protects against numerical instability.


288-291: Shape validation logic is sound.

The check for binary logits case (y_pred.shape[1] == 1 and not self.use_softmax) correctly allows shape mismatch when appropriate.

Comment on lines +213 to +226
+        # Concatenate losses
+        all_ce = torch.cat([back_ce.unsqueeze(1), fore_ce], dim=1)
+
+        # Sum over classes (dim=1) to get total loss per pixel
+        total_loss = torch.sum(all_ce, dim=1)
+
-class AsymmetricUnifiedFocalLoss(_Loss):
-    """
-    AsymmetricUnifiedFocalLoss is a variant of Focal Loss.
+        # Apply reduction
+        if self.reduction == LossReduction.MEAN.value:
+            return torch.mean(total_loss)
+        if self.reduction == LossReduction.SUM.value:
+            return torch.sum(total_loss)
+        if self.reduction == LossReduction.NONE.value:
+            return total_loss
+        return torch.mean(total_loss)

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Find and examine the unified_focal_loss.py file
find . -name "unified_focal_loss.py" -type f | head -5

Repository: Project-MONAI/MONAI

Length of output: 99


🌐 Web query:

site:raw.githubusercontent.com monai/losses/unified_focal_loss.py

💡 Result:

I couldn't find a file at raw.githubusercontent.com matching "monai/losses/unified_focal_loss.py". However, the Unified Focal Loss is implemented and available in these locations:

  • MONAI repository / loss docs (MONAI provides related loss implementations; see loss docs). [1]
  • Original Unified Focal Loss code (author's repo, TensorFlow/Keras). [2]
  • Unofficial PyTorch implementations referencing MONAI's implementation. [3]

Sources:
[1] MONAI loss functions documentation / Project-MONAI repository. (docs.monai.io)
[2] Unified Focal Loss paper and code (mlyg/unified-focal-loss). (arxiv.org)
[3] Unofficial PyTorch implementation referencing MONAI. (github.com)

If you want, I can:

  • search raw.githubusercontent.com for other possible paths,
  • fetch the MONAI source file directly from the Project-MONAI GitHub and show its contents, or
  • provide the unified focal loss implementation (PyTorch) copied here. Which would you prefer?


🏁 Script executed:

#!/bin/bash
# Read the unified_focal_loss.py file to examine the code
wc -l ./monai/losses/unified_focal_loss.py

Repository: Project-MONAI/MONAI

Length of output: 103


🏁 Script executed:

#!/bin/bash
# Read the AsymmetricFocalLoss class and the relevant sections
sed -n '190,230p' ./monai/losses/unified_focal_loss.py

Repository: Project-MONAI/MONAI

Length of output: 1550


🏁 Script executed:

#!/bin/bash
# Read the AsymmetricUnifiedFocalLoss class, especially around line 296
sed -n '270,310p' ./monai/losses/unified_focal_loss.py

Repository: Project-MONAI/MONAI

Length of output: 1328


🏁 Script executed:

#!/bin/bash
# Find AsymmetricFocalTverskyLoss forward method to understand its output shape
grep -n "class AsymmetricFocalTverskyLoss" ./monai/losses/unified_focal_loss.py

Repository: Project-MONAI/MONAI

Length of output: 106


🏁 Script executed:

#!/bin/bash
# Check the test file for reduction=NONE scenarios
grep -n "reduction.*[Nn]one\|NONE" ./tests/test_unified_focal_loss.py 2>/dev/null || echo "Test file not found or no NONE tests"

Repository: Project-MONAI/MONAI

Length of output: 99


🏁 Script executed:

#!/bin/bash
# Read the AsymmetricFocalTverskyLoss forward method to see its output shape
sed -n '23,150p' ./monai/losses/unified_focal_loss.py

Repository: Project-MONAI/MONAI

Length of output: 5311


🏁 Script executed:

#!/bin/bash
# Get a better view of the entire AsymmetricFocalTverskyLoss class
sed -n '23,100p' ./monai/losses/unified_focal_loss.py | head -80

Repository: Project-MONAI/MONAI

Length of output: 3389


🏁 Script executed:

#!/bin/bash
# Check what axis is used in AsymmetricFocalTverskyLoss for reduction
sed -n '85,95p' ./monai/losses/unified_focal_loss.py

Repository: Project-MONAI/MONAI

Length of output: 532


🏁 Script executed:

#!/bin/bash
# Verify the entire test file to confirm no NONE reduction tests
cat ./tests/test_unified_focal_loss.py | grep -i "reduction" | head -20

Repository: Project-MONAI/MONAI

Length of output: 129


🏁 Script executed:

#!/bin/bash
# Find test file location
find . -name "test_unified_focal_loss.py" -type f

Repository: Project-MONAI/MONAI

Length of output: 104


🏁 Script executed:

#!/bin/bash
# Look for tests directory and list test files
ls -la ./tests/ | grep unified

Repository: Project-MONAI/MONAI

Length of output: 45


🏁 Script executed:

#!/bin/bash
# Check the test file
cat ./tests/losses/test_unified_focal_loss.py | grep -i "reduction\|NONE" | head -30

Repository: Project-MONAI/MONAI

Length of output: 72


🏁 Script executed:

#!/bin/bash
# Count lines and get overview of tests
wc -l ./tests/losses/test_unified_focal_loss.py

Repository: Project-MONAI/MONAI

Length of output: 108


🏁 Script executed:

#!/bin/bash
# Check what test cases exist
grep -n "def test_\|reduction" ./tests/losses/test_unified_focal_loss.py

Repository: Project-MONAI/MONAI

Length of output: 274


🏁 Script executed:

#!/bin/bash
# Read the entire test file to understand test coverage
cat ./tests/losses/test_unified_focal_loss.py

Repository: Project-MONAI/MONAI

Length of output: 4265


🏁 Script executed:

#!/bin/bash
# Get the exact line numbers for the forward method of AsymmetricUnifiedFocalLoss
sed -n '280,300p' ./monai/losses/unified_focal_loss.py

Repository: Project-MONAI/MONAI

Length of output: 1000


🏁 Script executed:

#!/bin/bash
# Check the default reduction value in __init__ for AsymmetricUnifiedFocalLoss
sed -n '240,280p' ./monai/losses/unified_focal_loss.py

Repository: Project-MONAI/MONAI

Length of output: 1706


Shape mismatch in AsymmetricUnifiedFocalLoss when reduction=NONE

AsymmetricFocalLoss returns shape (B, H, W, [D]) (per-pixel) with reduction=NONE, while AsymmetricFocalTverskyLoss returns shape (B, C) (per-class). Line 296's addition will fail when both use reduction=NONE.

Recommend either:

  1. Document that reduction=NONE is unsupported for AsymmetricUnifiedFocalLoss
  2. Add runtime check to prevent NONE reduction in AsymmetricUnifiedFocalLoss
  3. Align output shapes so both return matching dimensions

Also add test coverage for reduction=NONE scenario.
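
A short standalone sketch of the incompatibility, assuming both constituent losses accept reduction="none" and default arguments otherwise:

import torch

from monai.losses.unified_focal_loss import AsymmetricFocalLoss, AsymmetricFocalTverskyLoss

y_pred = torch.randn(2, 3, 8, 8)
y_true = torch.nn.functional.one_hot(torch.randint(0, 3, (2, 8, 8)), num_classes=3).permute(0, 3, 1, 2).float()

focal = AsymmetricFocalLoss(reduction="none")(y_pred, y_true)           # per-pixel, e.g. (2, 8, 8)
tversky = AsymmetricFocalTverskyLoss(reduction="none")(y_pred, y_true)  # per-class, e.g. (2, 3)
print(focal.shape, tversky.shape)  # incompatible shapes; a weighted sum of the two would fail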

🤖 Prompt for AI Agents
In monai/losses/unified_focal_loss.py around lines 213 to 226, the final
concatenation/sum assumes compatible per-pixel vs per-class shapes and will
break when constituent losses return reduction=NONE (AsymmetricFocalLoss returns
per-pixel (B,H,W,...), AsymmetricFocalTverskyLoss returns per-class (B,C)), so
add a runtime guard that rejects reduction==LossReduction.NONE (raise a
ValueError with a clear message that AsymmetricUnifiedFocalLoss does not support
reduction=NONE) or alternatively coerce both component losses to use a supported
reduction (SUM/MEAN) before combining; update the docstring to state NONE is
unsupported for this unified loss and add unit tests asserting a ValueError is
raised when reduction=NONE and tests for SUM/MEAN behavior.

Signed-off-by: ytl0623 <david89062388@gmail.com>
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
monai/losses/unified_focal_loss.py (1)

240-240: Remove unused num_classes parameter.

The num_classes parameter is stored but never used in the implementation. It's not passed to internal loss instances and doesn't affect behavior. Either use it to validate inputs or remove it from the interface.

♻️ Duplicate comments (3)
monai/losses/unified_focal_loss.py (3)

83-83: Add stacklevel=2 to warning for proper caller attribution.

The warning call should specify stacklevel=2 to point to the caller's code rather than this line.

🔎 Proposed fix
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)

188-188: Add stacklevel=2 to warning for proper caller attribution.

The warning call should specify stacklevel=2 to point to the caller's code.

🔎 Proposed fix
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)

293-296: Handle or reject reduction=NONE to prevent shape mismatch.

When reduction=NONE, AsymmetricFocalLoss returns shape (B, H, W, [D]) (per-pixel) while AsymmetricFocalTverskyLoss returns shape (B, C) (per-class). Line 296's addition will fail. Either document and reject NONE reduction with a runtime check, or ensure both losses return compatible shapes.

🔎 Proposed fix - reject NONE reduction
     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         """
         Args:
             y_pred: Prediction logits. Shape: (B, C, H, W, [D]).
                     Supports binary (C=1 or C=2) and multi-class (C>2) segmentation.
             y_true: Ground truth labels. Shape should match y_pred (or be indices if to_onehot_y is True).
         """
+        if self.reduction == LossReduction.NONE.value:
+            raise ValueError("AsymmetricUnifiedFocalLoss does not support reduction='none' due to incompatible output shapes from component losses.")
+
         if y_pred.shape != y_true.shape:
🧹 Nitpick comments (2)
monai/losses/unified_focal_loss.py (2)

51-53: Clarify sigmoid behavior for multi-channel inputs.

The docstring states sigmoid is used "for binary/multi-label" when use_softmax=False, but it doesn't clarify that sigmoid is applied independently to each channel in the multi-channel case. This differs from the single-channel case, where an explicit background channel is constructed. Consider adding: "For multi-channel inputs, sigmoid is applied per-channel independently (multi-label)."
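
For instance (illustration only, not code from the module), the two activations behave quite differently on a 3-channel prediction:

import torch

logits = torch.randn(1, 3, 2, 2)
per_channel = torch.sigmoid(logits)           # multi-label: each channel independent in [0, 1]
mutually_excl = torch.softmax(logits, dim=1)  # multi-class: channels sum to 1 at every pixel
print(per_channel.sum(dim=1))    # generally not equal to 1
print(mutually_excl.sum(dim=1))  # all ones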


159-159: Enhance use_softmax documentation with usage guidance.

Similar to AsymmetricFocalTverskyLoss, the docstring should clarify when to use softmax (mutually exclusive classes) vs. sigmoid (multi-label/overlapping classes). Consider adding: "Use True for mutually exclusive multi-class segmentation, False for binary or multi-label scenarios."

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 05dac9e and cbed38d.

📒 Files selected for processing (1)
  • monai/losses/unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/losses/unified_focal_loss.py

Signed-off-by: ytl0623 <david89062388@gmail.com>
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

♻️ Duplicate comments (1)
monai/losses/unified_focal_loss.py (1)

175-200: Same redundant probability conversion issue as AsymmetricFocalTverskyLoss.

Lines 175-182 duplicate the auto-handle logic from AsymmetricFocalTverskyLoss with the same flaw: single-channel always uses sigmoid regardless of use_softmax setting.

Apply the same fix as suggested for AsymmetricFocalTverskyLoss to respect use_softmax after expanding to 2-channel.

🧹 Nitpick comments (3)
monai/losses/unified_focal_loss.py (3)

213-226: Reduction logic duplicates default fallback.

Lines 220-226 explicitly handle MEAN/SUM/NONE, then line 226 returns torch.mean(total_loss) as fallback. This fallback is unreachable if LossReduction enum is exhaustive.

🔎 Simplification

Remove redundant fallback or add a warning if an unknown reduction is encountered:

         if self.reduction == LossReduction.MEAN.value:
             return torch.mean(total_loss)
         if self.reduction == LossReduction.SUM.value:
             return torch.sum(total_loss)
         if self.reduction == LossReduction.NONE.value:
             return total_loss
-        return torch.mean(total_loss)
+        raise ValueError(f"Unsupported reduction: {self.reduction}")

Same applies to AsymmetricFocalTverskyLoss lines 122-129.


288-291: Shape validation allows mismatch only for binary logits, but one-hot conversion happens downstream.

Lines 288-291 permit shape mismatch if is_binary_logits (C=1 with sigmoid) or if to_onehot_y=True. However, the internal losses perform one-hot conversion independently. If y_true has mismatched shape and to_onehot_y=False, the internal losses will raise ValueError at their shape checks (lines 89, 194).

This validation is redundant; the internal losses already enforce shape compatibility.

🔎 Simplification

Remove this check and let internal losses handle validation:

-        if y_pred.shape != y_true.shape:
-            is_binary_logits = y_pred.shape[1] == 1 and not self.use_softmax
-            if not self.to_onehot_y and not is_binary_logits:
-                raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
-

Or add a comment explaining why this pre-validation is needed.


89-89: Long exception messages flagged by static analysis.

Lines 89, 194, and 291 embed long f-string messages directly in ValueError. Ruff (TRY003) suggests defining exception classes or message constants for long messages.

For consistency with MONAI conventions, verify if other loss modules use inline messages or constants. If this pattern is acceptable project-wide, ignore the hint.

Also applies to: 194-194, 291-291

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between cbed38d and b08de65.

📒 Files selected for processing (1)
  • monai/losses/unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/losses/unified_focal_loss.py
🪛 Ruff (0.14.10)
monai/losses/unified_focal_loss.py

89-89: Avoid specifying long messages outside the exception class

(TRY003)


194-194: Avoid specifying long messages outside the exception class

(TRY003)


291-291: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (2)
monai/losses/unified_focal_loss.py (2)

266-279: Gamma parameters have opposite semantics in AsymmetricFocalLoss vs AsymmetricFocalTverskyLoss.

AsymmetricFocalLoss uses gamma directly: torch.pow(1 - y_pred, gamma), while AsymmetricFocalTverskyLoss uses its reciprocal: torch.pow(1 - dice_class, 1/gamma). Per the paper, Focal Tversky's optimal gamma=4/3 enhances loss (contrary to Focal loss which suppresses). Passing the same gamma=0.5 to both produces mismatched behaviors and may not match the paper's unified formulation intent.
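
A quick numerical illustration of the mismatch (values chosen arbitrarily; gamma=0.5 as in the shared default mentioned above):

import torch

gamma = 0.5
err = torch.tensor(0.6)                   # an error term, e.g. 1 - p_t or 1 - dice
focal_term = torch.pow(err, gamma)        # exponent < 1: ~0.775
tversky_term = torch.pow(err, 1 / gamma)  # exponent > 1: 0.36
print(focal_term.item(), tversky_term.item())

With the same gamma the two terms are modulated in opposite directions.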


114-115: The focal modulation formula is correct—it properly focuses on hard examples, not the reverse.

With gamma = 0.75 (default), the exponent 1/gamma = 1.333 > 1, which makes hard examples (low Dice values) contribute more to the loss than easy examples, not less. When you raise numbers to a power greater than 1, small values (easy examples where 1 - dice is small) decrease more than large values (hard examples where 1 - dice is large), so easy examples are down-weighted relative to hard examples. This is standard focal behavior and matches the docstring: "focal exponent value to down-weight easy foreground examples." The Unified Focal Loss specifies γ < 1 increases focusing on harder examples, and MONAI's reparameterization using 1/gamma as the exponent achieves this—gamma = 0.75 yields exponent 1.333, which focuses on hard examples correctly.
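
A small check of that claim with the default gamma=0.75 (standalone sketch):

import torch

gamma = 0.75
easy = torch.tensor(0.1)   # 1 - dice for an easy example
hard = torch.tensor(0.9)   # 1 - dice for a hard example
print(torch.pow(easy, 1 / gamma).item())  # ~0.046: shrinks well below 0.1
print(torch.pow(hard, 1 / gamma).item())  # ~0.869: stays close to 0.9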

Likely an incorrect or invalid review comment.

Comment on lines +69 to +96
         # Auto-handle single channel input (binary segmentation case)
         if y_pred.shape[1] == 1 and not self.use_softmax:
             y_pred = torch.sigmoid(y_pred)
             y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
             is_already_prob = True
             if y_true.shape[1] == 1:
                 y_true = one_hot(y_true, num_classes=2)
         else:
             is_already_prob = False

         n_pred_ch = y_pred.shape[1]

         if self.to_onehot_y:
             if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
             else:
-                y_true = one_hot(y_true, num_classes=n_pred_ch)
+                if y_true.shape[1] != n_pred_ch:
+                    y_true = one_hot(y_true, num_classes=n_pred_ch)

         if y_true.shape != y_pred.shape:
             raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

-        # clip the prediction to avoid NaN
+        # Convert logits to probabilities if not already done
         if not is_already_prob:
             if self.use_softmax:
                 y_pred = torch.softmax(y_pred, dim=1)
             else:
                 y_pred = torch.sigmoid(y_pred)

⚠️ Potential issue | 🟠 Major

Redundant probability conversion when auto-handling single-channel input.

Lines 69-77 apply sigmoid, expand to two channels, and set is_already_prob=True, so lines 91-96 correctly skip a second conversion on that path. However, the branch only runs when use_softmax=False: with use_softmax=True and a single-channel input, the auto-handling is skipped and softmax is later applied over a single channel, which produces a constant probability of 1.0 instead of the 2-channel representation created at line 72.

In short, single-channel inputs are effectively forced onto the sigmoid path; requesting use_softmax=True for them silently degenerates, which may not align with user intent.

🔎 Suggested approach

Apply the activation chosen by use_softmax after expanding to 2-channel:

-        if y_pred.shape[1] == 1 and not self.use_softmax:
+        if y_pred.shape[1] == 1:
             y_pred = torch.sigmoid(y_pred)
             y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
-            is_already_prob = True
+            if self.use_softmax:
+                y_pred = torch.softmax(y_pred, dim=1)
+            is_already_prob = True
             if y_true.shape[1] == 1:
                 y_true = one_hot(y_true, num_classes=2)
         else:
             is_already_prob = False

Or document that single-channel always uses sigmoid regardless of use_softmax.

🧰 Tools
🪛 Ruff (0.14.10)

89-89: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In monai/losses/unified_focal_loss.py around lines 69-96, the code currently
always applies sigmoid when auto-handling a single-channel prediction, which is
wrong when use_softmax=True; instead, for single-channel inputs defer
activation: if y_pred has 1 channel and use_softmax is True, expand the logits
to two channels (create a compatible two-channel logits tensor, e.g., by
concatenating the negated and original logits or an appropriate pair) and leave
is_already_prob=False so the later softmax branch runs; if use_softmax is False,
perform the sigmoid, concatenate probabilities to two channels and set
is_already_prob=True; ensure subsequent one-hot conversion and shape checks
remain unchanged.

Comment on lines +110 to +129
         # Calculate losses separately for each class
         # Background: Standard Dice Loss
         back_dice = 1 - dice_class[:, 0]
-        fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma)
-
-        # Average class scores
-        loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1))
-        return loss
+        # Foreground: Focal Tversky Loss
+        fore_dice = torch.pow(1 - dice_class[:, 1:], 1 / self.gamma)
+
+        # Concatenate background and foreground losses
+        # back_dice needs unsqueeze to match dimensions: (B,) -> (B, 1)
+        all_losses = torch.cat([back_dice.unsqueeze(1), fore_dice], dim=1)
+
+        # Apply reduction
+        if self.reduction == LossReduction.MEAN.value:
+            return torch.mean(all_losses)
+        if self.reduction == LossReduction.SUM.value:
+            return torch.sum(all_losses)
+        if self.reduction == LossReduction.NONE.value:
+            return all_losses
+
+        return torch.mean(all_losses)

⚠️ Potential issue | 🔴 Critical

Return shape with reduction=NONE is (B, C), not per-pixel.

With reduction=NONE, line 127 returns all_losses with shape (B, C) (per-class). This differs from AsymmetricFocalLoss, which returns per-pixel shape. AsymmetricUnifiedFocalLoss combines both at line 296, causing a shape mismatch when reduction=NONE.

Past review comment flagged this but it remains unresolved.

🔎 Recommended fix

Add runtime guard in AsymmetricUnifiedFocalLoss:

     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
+        if self.reduction == LossReduction.NONE.value:
+            raise ValueError("AsymmetricUnifiedFocalLoss does not support reduction='none' due to incompatible output shapes from constituent losses.")

Update docstring to document this limitation and add test coverage.

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In monai/losses/unified_focal_loss.py around lines 110 to 129, with
reduction=NONE the function returns all_losses shaped (B, C) (per-class) which
mismatches the per-pixel shape expected by AsymmetricUnifiedFocalLoss; add a
runtime guard to detect reduction==LossReduction.NONE and either (preferable per
the review) raise a clear ValueError explaining that AsymmetricUnifiedFocalLoss
cannot be used with reduction=NONE because this implementation returns per-class
losses, or alternatively change the reduction=NONE branch to compute and return
per-pixel losses to match AsymmetricFocalLoss; update the function docstring to
document this limitation (or the new behavior) and add unit tests that assert
the ValueError is raised (or validate the corrected per-pixel output shape) to
prevent regressions.

Successfully merging this pull request may close these issues.

Add sigmoid/softmax interface for AsymmetricUnifiedFocalLoss
