Add sigmoid/softmax interface for AsymmetricUnifiedFocalLoss #8669
base: dev
Conversation
Walkthrough: Adds a sigmoid/softmax interface option for AsymmetricUnifiedFocalLoss.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 3
🧹 Nitpick comments (2)
monai/losses/unified_focal_loss.py (2)
76-82: Add `stacklevel=2` to the warning. Per static analysis and Python conventions, set stacklevel to point to the caller.
- warnings.warn("single channel prediction, `include_background=False` ignored.")
+ warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)
177-194: Logic is correct; consider documenting the return value. The forward method correctly passes logits to FocalLoss and probabilities to AsymmetricFocalTverskyLoss. Per coding guidelines, docstrings should document return values.
  """
  Args:
      y_pred: (BNH[WD]) Logits (raw scores).
      y_true: (BNH[WD]) Ground truth labels.
+
+     Returns:
+         torch.Tensor: Weighted combination of focal loss and asymmetric focal Tversky loss.
  """
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/losses/unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/losses/unified_focal_loss.py
🧬 Code graph analysis (1)
monai/losses/unified_focal_loss.py (3)
monai/losses/focal_loss.py (1)
FocalLoss (26-202)
monai/networks/utils.py (1)
one_hot (170-220)
monai/utils/enums.py (1)
LossReduction (253-264)
🪛 Ruff (0.14.8)
monai/losses/unified_focal_loss.py
78-78: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
85-85: Avoid specifying long messages outside the exception class
(TRY003)
127-127: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (1)
monai/losses/unified_focal_loss.py (1)
36-44: Tests already cover the unified focal loss implementation. New tests were added to cover the changes. The PR indicates that test coverage has been implemented, so this concern can be closed.
Actionable comments posted: 1
♻️ Duplicate comments (2)
monai/losses/unified_focal_loss.py (2)
157-157: Replace the non-English comment ("儲存參數", i.e. "store parameter") with English.
- self.use_softmax = use_softmax  # 儲存參數
+ self.use_softmax = use_softmax
112-114: Numerical instability when dice approaches 1.0. When `dice_class[:, i]` equals 1.0, `torch.pow(0, -self.gamma)` produces infinity, causing NaN gradients.
Proposed fix:
- # Foreground classes: apply focal modulation
- # Original logic: (1 - dice) * (1 - dice)^(-gamma) -> (1-dice)^(1-gamma)
- loss_list.append((1 - dice_class[:, i]) * torch.pow(1 - dice_class[:, i], -self.gamma))
+ # Foreground classes: apply focal modulation
+ back_dice = torch.clamp(1 - dice_class[:, i], min=self.epsilon)
+ loss_list.append(back_dice * torch.pow(back_dice, -self.gamma))
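As a side note, here is a minimal standalone sketch of the failure mode and the clamping fix described in this comment; the epsilon value is an illustrative assumption, not the module's actual constant:

```python
import torch

gamma = 0.5
dice = torch.tensor([1.0, 0.9])  # first entry: a class predicted perfectly

# Original formulation: (1 - dice) * (1 - dice)^(-gamma)
unstable = (1 - dice) * torch.pow(1 - dice, -gamma)
print(unstable)  # first element is 0 * inf -> nan

# Clamping keeps the base strictly positive, so the product stays finite
base = torch.clamp(1 - dice, min=1e-7)  # epsilon chosen only for illustration
stable = base * torch.pow(base, -gamma)
print(stable)  # finite for both entries
```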
🧹 Nitpick comments (1)
monai/losses/unified_focal_loss.py (1)
74-80: Background exclusion logic is correct. The implementation properly handles the single-channel edge case.
Optional: Add stacklevel to warning
- warnings.warn("single channel prediction, `include_background=False` ignored.")
+ warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)
📒 Files selected for processing (1)
monai/losses/unified_focal_loss.py
🪛 Ruff (0.14.8)
monai/losses/unified_focal_loss.py
76-76: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
83-83: Avoid specifying long messages outside the exception class
(TRY003)
125-125: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (3)
monai/losses/unified_focal_loss.py (3)
34-58: Constructor changes look good. The `include_background` parameter addition aligns with MONAI conventions and is properly documented.
159-173: Component initialization is appropriate. Both loss components are properly configured with consistent parameters.
181-188: The redundant `to_onehot_y` conversion concern can be dismissed. Both `focal_loss` and `asy_focal_tversky_loss` independently apply the `to_onehot_y` conversion when enabled. Since PyTorch's one-hot encoding creates new tensors (not in-place modifications), passing the same `y_true` reference to both components is safe and functionally correct. No conversion conflicts occur.
Actionable comments posted: 0
🧹 Nitpick comments (3)
monai/losses/unified_focal_loss.py (2)
76-76: Add `stacklevel=2` to the warning. Missing explicit stacklevel for proper traceback.
- warnings.warn("single channel prediction, `include_background=False` ignored.")
+ warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)
Based on static analysis hints.
176-193: Document that `y_pred` must be logits. The forward method expects raw logits since `FocalLoss` applies activation internally, and manual softmax/sigmoid is applied for the Tversky component. The docstring states this (line 179) but could be more explicit about the consequences of passing probabilities instead. Consider adding a note in the docstring:
  def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
      """
      Args:
-         y_pred: (BNH[WD]) Logits (raw scores).
+         y_pred: (BNH[WD]) Logits (raw scores, not probabilities).
+             Do not pass pre-activated inputs; activation is applied internally.
          y_true: (BNH[WD]) Ground truth labels.
      """
tests/losses/test_unified_focal_loss.py (1)
26-61: Add test coverage for edge cases. Current tests only cover perfect predictions with zero loss. Missing coverage for:
- Imperfect predictions (non-zero loss)
- `include_background=False` scenarios
- `to_onehot_y=True` with integer labels
- Multi-class softmax with imperfect predictions
Suggested additional test cases
# Case 2: Binary with include_background=False
[
    {"use_softmax": False, "include_background": False},
    {
        "y_pred": torch.tensor([[[[logit_pos, logit_neg], [logit_neg, logit_pos]]]]),
        "y_true": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]]),
    },
    0.0,  # Should still be zero for perfect prediction
],
# Case 3: Multi-class with to_onehot_y=True (integer labels)
[
    {"use_softmax": True, "include_background": True, "to_onehot_y": True},
    {
        "y_pred": torch.tensor([[[[logit_pos, logit_neg], [logit_neg, logit_neg], [logit_neg, logit_pos]]]]),
        "y_true": torch.tensor([[[[0, 2]]]]),  # Integer labels: class 0, class 2
    },
    0.0,
],
Do you want me to generate a complete test case addition or open an issue to track this?
📒 Files selected for processing (2)
monai/losses/unified_focal_loss.py
tests/losses/test_unified_focal_loss.py
🪛 Ruff (0.14.8)
monai/losses/unified_focal_loss.py
76-76: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
83-83: Avoid specifying long messages outside the exception class
(TRY003)
126-126: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (5)
monai/losses/unified_focal_loss.py (3)
184-187: LGTM: Correct activation choice for the Tversky loss. FocalLoss handles its own activation internally, so this manual conversion to probabilities for AsymmetricFocalTverskyLoss is correct. The activation choice (softmax vs sigmoid) properly follows the `use_softmax` flag.
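To illustrate the point, a small hedged sketch of that conversion; the helper name is hypothetical and not part of the PR:

```python
import torch


def to_probabilities(y_pred: torch.Tensor, use_softmax: bool) -> torch.Tensor:
    """Convert raw logits to probabilities for the region-based (Tversky) term."""
    # softmax across the channel dimension for mutually exclusive classes,
    # element-wise sigmoid otherwise
    return torch.softmax(y_pred, dim=1) if use_softmax else torch.sigmoid(y_pred)


logits = torch.randn(2, 3, 8, 8)  # (B, C, H, W)
probs = to_probabilities(logits, use_softmax=True)
print(probs.sum(dim=1).allclose(torch.ones(2, 8, 8)))  # True: channels sum to 1
```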
89-116: Implementation correctly handles `include_background` with the standard MONAI slicing pattern. When `include_background=False`, channel index 0 is excluded from the calculation; the code does this via tensor slicing at lines 79-80 before the asymmetry loop. Once sliced, all remaining channels receive focal modulation; none are treated as background. The loss only supports binary segmentation, so asymmetry designates the first present channel as background and all others as foreground, which is the intended behavior per the documented design comment (lines 101-104).
160-174: Both composed losses independently transform `y_true` with their respective settings. Each applies its own non-destructive transformations (one-hot encoding creates new tensors; slicing creates new views), so no actual collision occurs. This is correct by design: composed losses should handle their own input transformations.
tests/losses/test_unified_focal_loss.py (2)
22-24: LGTM: High-confidence logits ensure clear test expectations. Using ±10.0 logits produces near-perfect probabilities (~0.9999 and ~0.0001), making zero-loss expectations reasonable for perfect predictions.
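A quick standalone check of the ±10.0 logit claim above:

```python
import torch

print(torch.sigmoid(torch.tensor([10.0, -10.0])))
# tensor([9.9995e-01, 4.5398e-05]) -- effectively 1 and 0 for loss purposes
```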
77-89: LGTM: CUDA test correctly instantiates the loss. The test properly moves the loss module to CUDA (line 85), ensuring both model parameters and inputs are on the same device.
Force-pushed from 39664ea to 41dccad (Compare)
Actionable comments posted: 1
🧹 Nitpick comments (1)
monai/losses/unified_focal_loss.py (1)
74-80: Background exclusion correctly implemented. The logic properly removes the first channel when `include_background=False`, consistent with FocalLoss. The single-channel warning is appropriate.
Optional: Add stacklevel to the warning for better traceability
- warnings.warn("single channel prediction, `include_background=False` ignored.")
+ warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)
📒 Files selected for processing (2)
monai/losses/unified_focal_loss.py
tests/losses/test_unified_focal_loss.py
🪛 Ruff (0.14.8)
monai/losses/unified_focal_loss.py
76-76: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
83-83: Avoid specifying long messages outside the exception class
(TRY003)
126-126: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (8)
tests/losses/test_unified_focal_loss.py (3)
22-42: Binary test case is correct. The high-confidence logits (±10.0) correctly produce near-perfect probabilities after sigmoid. The alignment between predictions and targets should yield near-zero loss.
66-70: Test structure is correct. The parameterized test properly unpacks configuration and data, with appropriate numerical tolerances for floating-point comparison.
77-89: CUDA test correctly adapted to the logits interface. The test properly uses logits with `use_softmax=False` for binary segmentation and correctly moves both tensors and the loss module to CUDA.
monai/losses/unified_focal_loss.py (5)
19-19: Import is correct. FocalLoss is properly imported from monai.losses for reuse in the unified loss.
34-58: Constructor properly extended with include_background. The parameter is correctly documented, defaulted, and stored for use in the forward method, consistent with MONAI's loss interface patterns.
97-117: Asymmetric focal modulation correctly implemented. The background class uses standard Dice loss while foreground classes apply focal modulation (1 - dice)^(1 - gamma). Clamping prevents numerical instability when dice approaches 1.0.
135-174: Composition pattern correctly implemented. The constructor properly instantiates and configures both FocalLoss and AsymmetricFocalTverskyLoss components with shared parameters, enabling modular loss computation.
176-193: Forward method correctly combines the losses. FocalLoss operates on logits (with internal activation), while AsymmetricFocalTverskyLoss requires probabilities. The explicit softmax/sigmoid conversion for the Tversky component is correct, and the weighted combination is straightforward.
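To make the data flow concrete, a minimal sketch of the composition this round describes; the attribute names (`focal_loss`, `asy_focal_tversky_loss`, `use_softmax`, `weight`) are taken from the discussion and may not match the PR verbatim:

```python
import torch


def unified_forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    # Distribution-based term: FocalLoss consumes raw logits and applies its own activation.
    focal = self.focal_loss(y_pred, y_true)
    # Region-based term: AsymmetricFocalTverskyLoss expects probabilities,
    # so the logits are activated explicitly first.
    probs = torch.softmax(y_pred, dim=1) if self.use_softmax else torch.sigmoid(y_pred)
    tversky = self.asy_focal_tversky_loss(probs, y_true)
    # Weighted combination of the two terms.
    return self.weight * focal + (1 - self.weight) * tversky
```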
Actionable comments posted: 0
🧹 Nitpick comments (1)
tests/losses/test_unified_focal_loss.py (1)
26-56: Add a test case for `include_background=False`. The test cases cover sigmoid/softmax modes correctly, but the `include_background` parameter (added per the PR objectives) is only tested with `True`. Add a multi-class case with `include_background=False` to validate background channel exclusion.
Example test case
# Case 2: Multi-class with background excluded
[
    {"use_softmax": True, "include_background": False},
    {
        "y_pred": torch.tensor([[[[logit_pos, logit_neg]], [[logit_neg, logit_pos]], [[logit_neg, logit_neg]]]]),
        "y_true": torch.tensor([[[[1.0, 0.0]], [[0.0, 1.0]], [[0.0, 0.0]]]]),
    },
    0.0,
],
📒 Files selected for processing (1)
tests/losses/test_unified_focal_loss.py
🔇 Additional comments (3)
tests/losses/test_unified_focal_loss.py (3)
22-24: LGTM - Clear test constants. Helper logits are well-defined for creating high-confidence predictions.
62-65: LGTM - Parameterized test structure correct. The test method properly unpacks the config and data dicts.
72-84: LGTM - CUDA test properly implemented. The test correctly uses logits and moves both tensors and the loss module to CUDA.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/losses/test_unified_focal_loss.py (1)
79-82: Add docstring. Per coding guidelines, add a docstring describing that this test validates shape mismatch error handling.
🧹 Nitpick comments (1)
tests/losses/test_unified_focal_loss.py (1)
26-68: Add at least one test with non-zero loss. All test cases expect 0.0 loss with perfect predictions. Add a case with imperfect predictions (e.g., logits slightly off from ground truth) to verify the loss is actually computed, not just validating tensor shape compatibility.
Optional: Expand parameter coverage
Consider adding test cases that vary:
- `to_onehot_y=True` with class-index format ground truth
- `weight`, `delta`, `gamma` set to non-default values
- `reduction` modes (SUM, NONE)
These are optional and can be deferred.
📒 Files selected for processing (1)
tests/losses/test_unified_focal_loss.py
🔇 Additional comments (1)
tests/losses/test_unified_focal_loss.py (1)
22-24: LGTM: High-confidence logits for perfect-prediction tests. Values create predictions very close to 0 or 1, suitable for validating near-zero loss on ideal inputs.
Actionable comments posted: 0
🧹 Nitpick comments (2)
tests/losses/test_unified_focal_loss.py (1)
26-68: Test coverage is adequate for the new interface. The three cases properly exercise binary sigmoid, multi-class softmax, and background exclusion. All use perfect predictions (loss=0.0), which validates the happy path.
Consider adding at least one test case with imperfect predictions (non-zero expected loss) to verify the actual loss computation, not just the zero-loss edge case. This would increase confidence in the refactored implementation.
monai/losses/unified_focal_loss.py (1)
134-173: Excellent refactoring using composition. Replacing inline logic with composed `FocalLoss` and `AsymmetricFocalTverskyLoss` improves maintainability and reusability. Parameters are correctly forwarded to both components.
Per past review feedback, the `use_softmax` docstring (lines 151-152) could add brief guidance: softmax for mutually exclusive classes (standard multi-class), sigmoid for multi-label/overlapping classes.
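A hedged usage sketch of that guidance, using the constructor arguments discussed in this PR:

```python
from monai.losses import AsymmetricUnifiedFocalLoss

# Mutually exclusive multi-class segmentation: softmax over the channel dimension.
multiclass_loss = AsymmetricUnifiedFocalLoss(use_softmax=True, to_onehot_y=True)

# Binary or multi-label segmentation: per-channel sigmoid on the logits.
binary_loss = AsymmetricUnifiedFocalLoss(use_softmax=False)
```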
📒 Files selected for processing (2)
monai/losses/unified_focal_loss.py
tests/losses/test_unified_focal_loss.py
🪛 Ruff (0.14.8)
monai/losses/unified_focal_loss.py
83-83: Avoid specifying long messages outside the exception class
(TRY003)
125-125: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (9)
tests/losses/test_unified_focal_loss.py (3)
22-24: LGTM. Helper constants are well-commented and appropriate for generating high-confidence predictions in tests.
74-85: LGTM. The docstring properly documents parameters. The parameterized test structure is clean and flexible.
92-104: LGTM. The CUDA test correctly uses the new API with sigmoid activation and validates GPU compatibility.
monai/losses/unified_focal_loss.py (6)
19-19: LGTM. Import required for the new composition-based implementation.
34-58: LGTM. Adding the `include_background` parameter aligns with MONAI loss function conventions and enables proper multi-class segmentation support.
74-80: LGTM. The background exclusion logic correctly follows the FocalLoss pattern, including the single-channel warning.
106-114: Asymmetry logic is correct. The background channel (index 0 when `include_background=True`) uses standard Dice loss, while foreground channels use focal modulation. When `include_background=False`, all channels receive focal modulation since the background was removed. Clamping prevents numerical instability.
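For reference, a minimal sketch of the background-exclusion slicing described here; variable names are illustrative, not copied from the diff:

```python
import torch

include_background = False
y_pred = torch.rand(1, 3, 4, 4)                      # (B, C, H, W) probabilities
y_true = torch.randint(0, 2, (1, 3, 4, 4)).float()   # one-hot style ground truth

if not include_background:
    # drop channel 0 so the asymmetric loop only sees foreground classes
    y_pred = y_pred[:, 1:]
    y_true = y_true[:, 1:]

print(y_pred.shape, y_true.shape)  # torch.Size([1, 2, 4, 4]) for both
```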
118-125: LGTM. The reduction logic correctly handles MEAN, SUM, and NONE cases with an appropriate error for unsupported values.
175-195: LGTM. The forward pass correctly handles different input requirements: logits for `FocalLoss` (which applies activation internally), probabilities for `AsymmetricFocalTverskyLoss`. The weighted combination is straightforward and matches the documented formula.
Force-pushed from 2f4657e to e63e36e (Compare)
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/losses/test_unified_focal_loss.py (1)
117-128: Add a docstring for the CUDA test. Per coding guidelines and a past review comment, add a docstring describing this test's purpose: validating CUDA compatibility with perfect predictions.
Suggested docstring
  def test_with_cuda(self):
+     """Test AsymmetricUnifiedFocalLoss CUDA compatibility with perfect predictions."""
      loss = AsymmetricUnifiedFocalLoss()
🧹 Nitpick comments (3)
tests/losses/test_unified_focal_loss.py (1)
25-93: Suggest adding imperfect prediction test cases. All three cases test perfect predictions (loss=0.0). Add at least one case with misaligned logits/labels to verify that the loss computes non-zero values correctly and gradients flow properly.
Example imperfect case
[  # Case 3: Imperfect prediction
    {"use_softmax": False, "include_background": True},
    {
        "y_pred": torch.tensor([[[[0.0, -2.0], [2.0, 0.0]]]]),  # Moderate confidence
        "y_true": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]]),
    },
    # Expected: positive loss value (compute actual expected value)
],
monai/losses/unified_focal_loss.py (2)
60-86: LGTM: Background exclusion logic is correct. The `include_background` handling properly slices channel 0 from both tensors and warns on single-channel edge cases. Shape validation and clipping are correctly placed.
Note: Static analysis flags line 83 for a long exception message (TRY003). Consider a custom exception class if this pattern recurs, but current usage is acceptable.
88-125: LGTM: Asymmetric focal Tversky logic is sound. The per-class loss correctly applies standard Tversky to the background (when included) and focal-modulated Tversky to the foreground. Clamping prevents numerical instability. Reduction handling is complete.
Static analysis flags line 125 for a long exception message (TRY003). Consider extracting to a constant or custom exception if this pattern is reused.
📒 Files selected for processing (2)
monai/losses/unified_focal_loss.py
tests/losses/test_unified_focal_loss.py
🪛 Ruff (0.14.8)
monai/losses/unified_focal_loss.py
83-83: Avoid specifying long messages outside the exception class
(TRY003)
125-125: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (6)
tests/losses/test_unified_focal_loss.py (2)
22-23: LGTM: Clear test constants. Module-level logit constants are well-named and appropriate for testing high-confidence predictions.
99-110: LGTM: Well-documented parameterized test. Docstring and implementation are clear. Tolerance levels are appropriate.
monai/losses/unified_focal_loss.py (4)
19-19: LGTM: Import supports the composition pattern. The FocalLoss import enables the refactored AsymmetricUnifiedFocalLoss to reuse the existing focal loss implementation.
34-58: LGTM: Consistent API with the include_background parameter. Addition of `include_background` aligns with MONAI loss conventions. The default of `True` preserves backward compatibility.
128-175: LGTM: Clean composition of focal components. Refactoring to compose `FocalLoss` and `AsymmetricFocalTverskyLoss` eliminates code duplication and ensures consistent parameter handling. Docstrings clearly distinguish sigmoid vs. softmax use cases.
177-197: LGTM: Forward pass correctly combines the loss components. The focal loss operates on logits while the Tversky component operates on probabilities after explicit activation. Each component independently handles `to_onehot_y` and `include_background`, ensuring correct behavior. The weighted combination is straightforward.
Force-pushed from e63e36e to ccc5459 (Compare)
Actionable comments posted: 1
♻️ Duplicate comments (1)
monai/losses/unified_focal_loss.py (1)
157-157: Remove the non-English comment.
- self.use_softmax = use_softmax  # 儲存參數
+ self.use_softmax = use_softmax
This was previously flagged but remains in the code.
🧹 Nitpick comments (1)
monai/losses/unified_focal_loss.py (1)
118-125: Reduction logic is correct. The standard reduction pattern is implemented properly.
For consistency with MONAI style, consider extracting the long error message to a constant or shortening it (static analysis hint TRY003):
- raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
+ raise ValueError(f"Unsupported reduction: {self.reduction}")
📒 Files selected for processing (2)
monai/losses/unified_focal_loss.py
tests/losses/test_unified_focal_loss.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/losses/test_unified_focal_loss.py
🪛 Ruff (0.14.8)
monai/losses/unified_focal_loss.py
76-76: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
83-83: Avoid specifying long messages outside the exception class
(TRY003)
125-125: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (5)
monai/losses/unified_focal_loss.py (5)
19-19: LGTM. Import required for the new composition-based implementation.
34-59: LGTM. The `include_background` parameter is properly documented and maintains backward compatibility with `True` as the default.
74-81: Background exclusion logic is correct. The warning for single-channel predictions and the slicing logic are appropriate.
However, add `stacklevel=2` to the warning at line 76 for proper caller identification:
- warnings.warn("single channel prediction, `include_background=False` ignored.")
+ warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)
Based on static analysis hints.
159-173: Asymmetric gamma application is intentional and correct. The Unified Focal Loss design intentionally exploits gamma asymmetry to enable simultaneous suppression and enhancement effects in its component losses. In FocalLoss, gamma down-weights easy-to-classify pixels, while in the Focal Tversky Loss, gamma enhances rather than suppresses easy examples. Gamma controls weights for difficult-to-predict samples; distribution-based corrections apply sample-by-sample while region-based corrections apply class-by-class during macro-averaging. This composition pattern correctly implements the unified focal loss framework.
175-192: Forward implementation is correct. The loss properly:
- Computes focal loss on logits
- Converts logits to probabilities for Tversky component via softmax or sigmoid
- Combines losses with configurable weighting
Test coverage includes both sigmoid and softmax activation paths with appropriate input dimensions.
Force-pushed from 1d196dc to edb01ce (Compare)
Actionable comments posted: 3
♻️ Duplicate comments (1)
tests/losses/test_unified_focal_loss.py (1)
93-106: Add docstring. Per coding guidelines, add a docstring describing the test purpose.
🔎 Proposed fix
  def test_with_cuda(self):
+     """Validate CUDA compatibility of AsymmetricUnifiedFocalLoss."""
      if not torch.cuda.is_available():
Based on coding guidelines.
📒 Files selected for processing (2)
monai/losses/unified_focal_loss.py
tests/losses/test_unified_focal_loss.py
🧬 Code graph analysis (2)
monai/losses/unified_focal_loss.py (1)
monai/utils/enums.py (1)
LossReduction(253-264)
tests/losses/test_unified_focal_loss.py (3)
monai/losses/unified_focal_loss.py (1)
AsymmetricUnifiedFocalLoss (229-298)
tests/test_utils.py (1)
assert_allclose (119-159)
monai/networks/nets/quicknat.py (1)
is_cuda(433-437)
🪛 Ruff (0.14.10)
monai/losses/unified_focal_loss.py
83-83: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
89-89: Avoid specifying long messages outside the exception class
(TRY003)
188-188: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
194-194: Avoid specifying long messages outside the exception class
(TRY003)
291-291: Avoid specifying long messages outside the exception class
(TRY003)
tests/losses/test_unified_focal_loss.py
83-83: Unused method argument: expected_val
(ARG002)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: packaging
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: build-docs
🔇 Additional comments (7)
tests/losses/test_unified_focal_loss.py (2)
21-66: Test case definitions are well-structured. The test cases cover binary logits, 2-channel binary, and multi-class scenarios with appropriate shapes and parameter combinations. The use of 10.0/-10.0 logits ensures near-perfect probabilities for validation.
71-80: LGTM. A tolerance of 1e-3 is appropriate given that logits of ±10.0 don't yield exact probabilities of 0.0/1.0.
monai/losses/unified_focal_loss.py (5)
34-60: LGTM. The `use_softmax` parameter is properly integrated with clear documentation.
91-129: Loss calculations are correct. The background dice and foreground focal-tversky computations align with the paper's formulation. The use of a `1/gamma` exponent for foreground classes properly implements the focal modulation.
196-226: Focal loss implementation is correct. The asymmetric weighting (background focal, foreground standard CE) with `delta` balancing correctly addresses class imbalance.
237-279: Composition pattern is well-executed. Creating internal loss instances with shared parameters ensures consistency and avoids duplication in the forward pass.
281-298: Forward logic is sound. The shape validation correctly handles edge cases (binary logits, to_onehot_y), and the weighted combination properly unifies the focal and tversky losses.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
monai/losses/unified_focal_loss.py (1)
240-240: The `num_classes` parameter is unused. It is stored at line 260 but never referenced. Either remove it or use it.
🔎 Proposed fix to remove the unused parameter
  def __init__(
      self,
      to_onehot_y: bool = False,
-     num_classes: int = 2,
      weight: float = 0.5,
      gamma: float = 0.5,
      delta: float = 0.7,
      reduction: LossReduction | str = LossReduction.MEAN,
      use_softmax: bool = False,
  ):
      """
      Args:
          to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
-         num_classes: number of classes. Defaults to 2.
          weight: weight factor to balance between Focal Loss and Tversky Loss.
And remove `self.num_classes = num_classes` at line 260.
Also applies to: 260-260
♻️ Duplicate comments (2)
monai/losses/unified_focal_loss.py (2)
83-83: Add `stacklevel=2` to the warning. Per static analysis, add `stacklevel=2` so the warning points to the caller.
- warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+ warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
188-188: Add `stacklevel=2` to the warning. Per static analysis, add `stacklevel=2`.
- warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+ warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
🧹 Nitpick comments (6)
monai/losses/unified_focal_loss.py (6)
122-129: Unreachable fallback return. The final `return torch.mean(all_losses)` at line 129 is unreachable for valid `LossReduction` values. Consider raising an error for invalid reductions or removing the redundant return.
🔎 Proposed fix
  if self.reduction == LossReduction.MEAN.value:
      return torch.mean(all_losses)
  if self.reduction == LossReduction.SUM.value:
      return torch.sum(all_losses)
  if self.reduction == LossReduction.NONE.value:
      return all_losses
-
-     return torch.mean(all_losses)
+     raise ValueError(f"Unsupported reduction: {self.reduction}")
158-159: Incomplete docstring for the `reduction` parameter. The `reduction` parameter docstring is missing its description.
- reduction: {``"none"``, ``"mean"``, ``"sum"``}
- use_softmax: whether to use softmax to transform logits. Defaults to False.
+ reduction: {``"none"``, ``"mean"``, ``"sum"``}
+     Specifies the reduction to apply to the output. Defaults to ``"mean"``.
+ use_softmax: whether to use softmax to transform logits. Defaults to False.
175-202: Consider extracting shared preprocessing logic. Lines 175-202 duplicate the preprocessing from `AsymmetricFocalTverskyLoss` (lines 70-99): single-channel handling, one-hot conversion, shape validation, and probability conversion. Extract to a shared helper to reduce duplication.
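One possible shape for the suggested shared helper; the function name and exact signature are assumptions, not code from the PR:

```python
import warnings

import torch
from monai.networks.utils import one_hot


def _prepare_inputs(
    y_pred: torch.Tensor, y_true: torch.Tensor, to_onehot_y: bool, use_softmax: bool
) -> tuple[torch.Tensor, torch.Tensor]:
    """Common preprocessing currently repeated by both sub-losses."""
    n_pred_ch = y_pred.shape[1]
    if to_onehot_y:
        if n_pred_ch == 1:
            warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
        else:
            y_true = one_hot(y_true, num_classes=n_pred_ch)
    # convert logits to probabilities once, instead of separately in each sub-loss
    y_pred = torch.softmax(y_pred, dim=1) if use_softmax else torch.sigmoid(y_pred)
    return y_pred, y_true
```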
293-294: Sub-losses preprocess inputs independently, causing duplicate work. Both `self.asy_focal_loss` and `self.asy_focal_tversky_loss` independently apply sigmoid/softmax, one-hot encoding, and clamping to the same inputs. For performance, consider preprocessing once in this `forward` method and passing the processed tensors to sub-losses configured to skip preprocessing.
62-68: Docstring should document the return value and exceptions. Per coding guidelines, docstrings should describe the return value and raised exceptions.
  """
  Args:
      y_pred: prediction logits or probabilities. Shape should be (B, C, spatial_dims).
      y_true: ground truth labels. Shape should match y_pred.
+
+     Returns:
+         torch.Tensor: Computed loss. Shape depends on reduction setting.
+
+     Raises:
+         ValueError: If y_true and y_pred shapes don't match after preprocessing.
  """
168-173: Docstring should document the return value and exceptions. Same as `AsymmetricFocalTverskyLoss.forward`: add Returns and Raises sections.
📒 Files selected for processing (1)
monai/losses/unified_focal_loss.py
🪛 Ruff (0.14.10)
monai/losses/unified_focal_loss.py
83-83: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
89-89: Avoid specifying long messages outside the exception class
(TRY003)
188-188: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
194-194: Avoid specifying long messages outside the exception class
(TRY003)
291-291: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (2)
monai/losses/unified_focal_loss.py (2)
288-291: Shape validation is correct but duplicated. The shape check here is a reasonable early guard, though sub-losses will validate again after their own preprocessing. Acceptable as-is for fail-fast behavior.
296-298: LGTM! Clean delegation to sub-losses with a proper weighted combination.
Signed-off-by: ytl0623 <david89062388@gmail.com>
Force-pushed from 4166faa to 45d9877 (Compare)
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
monai/losses/unified_focal_loss.py (1)
237-279: Remove the unused `num_classes` parameter or document it for backward compatibility. The `num_classes` parameter is stored at line 260 but never used. It is not passed to the internal loss instances and not referenced in the forward method. Either remove it or add a comment explaining why it's retained.
♻️ Duplicate comments (4)
monai/losses/unified_focal_loss.py (2)
186-191: Add `stacklevel=2` to the warning. The `warnings.warn` call should specify `stacklevel=2` so the warning points to the caller's code.
🔎 Proposed fix
  if self.to_onehot_y:
      if n_pred_ch == 1:
-         warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+         warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
      else:
Based on static analysis hint.
81-86: Add `stacklevel=2` to the warning. The `warnings.warn` call should specify `stacklevel=2` so the warning points to the caller's code.
🔎 Proposed fix
  if self.to_onehot_y:
      if n_pred_ch == 1:
-         warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+         warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
      else:
Based on static analysis hint.
tests/losses/test_unified_focal_loss.py (2)
94-107: Add docstring. Per coding guidelines, add a docstring describing that this test validates CUDA compatibility of the loss.
🔎 Proposed fix
  def test_with_cuda(self):
+     """Verify CUDA compatibility by running the loss on GPU tensors when available."""
      if not torch.cuda.is_available():
83-87: Remove the unused parameter and add a docstring. The `expected_val` parameter is unused. Remove it from the signature and update the test case accordingly. Also add a docstring per coding guidelines.
🔎 Proposed fix
  @parameterized.expand([TEST_CASE_MULTICLASS_WRONG])
- def test_wrong_prediction(self, input_data, expected_val, args):
+ def test_wrong_prediction(self, input_data, args):
+     """Verify that wrong predictions yield high loss values."""
      loss_func = AsymmetricUnifiedFocalLoss(**args)
      result = loss_func(**input_data)
      self.assertGreater(result.item(), 1.0, "Loss should be high for wrong predictions")
Update TEST_CASE_MULTICLASS_WRONG at line 62 to remove the None value:
  TEST_CASE_MULTICLASS_WRONG = [
      {
          "y_pred": torch.tensor(
              [[[[-10.0, -10.0], [-10.0, -10.0]], [[10.0, 10.0], [10.0, 10.0]], [[-10.0, -10.0], [-10.0, -10.0]]]]
          ),
          "y_true": torch.tensor([[[[0, 0], [0, 0]]]]),  # GT is class 0, but Pred is class 1
      },
-     None,
      {"use_softmax": True, "to_onehot_y": True},
  ]
Based on static analysis hint.
📒 Files selected for processing (2)
monai/losses/unified_focal_loss.py
tests/losses/test_unified_focal_loss.py
🪛 Ruff (0.14.10)
tests/losses/test_unified_focal_loss.py
84-84: Unused method argument: expected_val
(ARG002)
monai/losses/unified_focal_loss.py
83-83: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
89-89: Avoid specifying long messages outside the exception class
(TRY003)
188-188: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
194-194: Avoid specifying long messages outside the exception class
(TRY003)
291-291: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: build-docs
- GitHub Check: packaging
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-py3 (3.12)
Actionable comments posted: 1
♻️ Duplicate comments (2)
monai/losses/unified_focal_loss.py (2)
81-86: Add `stacklevel=2` to the warning (still open from past reviews).
🔎 Proposed fix
  if self.to_onehot_y:
      if n_pred_ch == 1:
-         warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+         warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
      else:
Based on static analysis hint.
186-191: Add `stacklevel=2` to the warning (still open from past reviews).
🔎 Proposed fix
  if self.to_onehot_y:
      if n_pred_ch == 1:
-         warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+         warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
      else:
Based on static analysis hint.
🧹 Nitpick comments (3)
monai/losses/unified_focal_loss.py (3)
110-129: Loss computation is correct, but consider documenting the asymmetric treatment. The background uses standard Dice loss while the foreground uses focal modulation. This asymmetry is intentional to prioritize foreground classes, but could benefit from an inline comment for future maintainers.
🔎 Optional documentation enhancement
  # Calculate losses separately for each class
- # Background: Standard Dice Loss
+ # Background: Standard Dice Loss (no focal modulation to preserve sensitivity)
  back_dice = 1 - dice_class[:, 0]
- # Foreground: Focal Tversky Loss
+ # Foreground: Focal Tversky Loss (focal modulation to down-weight easy examples)
  fore_dice = torch.pow(1 - dice_class[:, 1:], 1 / self.gamma)
175-182: Optional: Extract the duplicated single-channel handling into a helper function. The same single-channel auto-conversion logic appears in both AsymmetricFocalTverskyLoss (lines 69-78) and AsymmetricFocalLoss. Consider extracting it to a shared helper if more losses adopt this pattern.
266-279: Consider exposing separate gamma parameters for the two loss components. AsymmetricFocalLoss defaults to `gamma=2.0` while AsymmetricFocalTverskyLoss defaults to `gamma=0.75`, but AsymmetricUnifiedFocalLoss forces both to use the same `gamma` value. This prevents users from independently tuning focal modulation for distribution-based (CE) vs region-based (Dice) objectives.
Not blocking, but consider adding `gamma_focal` and `gamma_tversky` parameters in a future revision if users request finer control.
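A sketch of what that hypothetical split could look like; `gamma_focal` and `gamma_tversky` are not parameters of the current PR:

```python
from monai.losses import AsymmetricFocalLoss, AsymmetricFocalTverskyLoss

# Hypothetical split of the single `gamma` argument -- not part of this PR.
gamma_focal = 2.0     # modulation strength for the distribution-based (cross-entropy) term
gamma_tversky = 0.75  # modulation strength for the region-based (Tversky) term
delta = 0.7

asy_focal_loss = AsymmetricFocalLoss(gamma=gamma_focal, delta=delta)
asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=gamma_tversky, delta=delta)
```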
📒 Files selected for processing (1)
monai/losses/unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Files:
monai/losses/unified_focal_loss.py
🧬 Code graph analysis (1)
monai/losses/unified_focal_loss.py (2)
monai/utils/enums.py (1)
LossReduction(253-264)monai/networks/utils.py (1)
one_hot(170-220)
🪛 Ruff (0.14.10)
monai/losses/unified_focal_loss.py
83-83: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
89-89: Avoid specifying long messages outside the exception class
(TRY003)
188-188: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
194-194: Avoid specifying long messages outside the exception class
(TRY003)
291-291: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (3)
monai/losses/unified_focal_loss.py (3)
69-78: Verify single-channel sigmoid conversion produces a valid probability distribution.
The auto-handling converts single-channel logits to two-channel probabilities via `torch.cat([1 - y_pred, y_pred], dim=1)` after sigmoid. This assumes `y_pred` after sigmoid is the foreground probability. Confirm this matches user expectations and aligns with the rest of MONAI's binary segmentation conventions; a quick check is sketched below.
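A quick standalone check of this construction (not the MONAI code itself): after sigmoid, concatenating `1 - p` and `p` along the channel dimension gives two channels that sum to one, i.e. a valid background/foreground pair.

```python
import torch

logits = torch.randn(2, 1, 4, 4)         # (B, 1, H, W) single-channel logits
p = torch.sigmoid(logits)                 # foreground probability
probs = torch.cat([1 - p, p], dim=1)      # (B, 2, H, W): background, foreground

# Channels sum to 1 at every pixel, so this is a valid 2-class distribution.
assert torch.allclose(probs.sum(dim=1), torch.ones(2, 4, 4))
```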
91-99: LGTM - probability conversion logic is correct.
The `is_already_prob` flag prevents double conversion for single-channel inputs, and clamping protects against numerical instability.
288-291: Shape validation logic is sound.
The check for the binary logits case (`y_pred.shape[1] == 1 and not self.use_softmax`) correctly allows a shape mismatch when appropriate.
  # Concatenate losses
  all_ce = torch.cat([back_ce.unsqueeze(1), fore_ce], dim=1)

  # Sum over classes (dim=1) to get total loss per pixel
  total_loss = torch.sum(all_ce, dim=1)

  # Apply reduction
  if self.reduction == LossReduction.MEAN.value:
      return torch.mean(total_loss)
  if self.reduction == LossReduction.SUM.value:
      return torch.sum(total_loss)
  if self.reduction == LossReduction.NONE.value:
      return total_loss
  return torch.mean(total_loss)


class AsymmetricUnifiedFocalLoss(_Loss):
    """
    AsymmetricUnifiedFocalLoss is a variant of Focal Loss.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Find and examine the unified_focal_loss.py file
find . -name "unified_focal_loss.py" -type f | head -5
Repository: Project-MONAI/MONAI
Length of output: 99
🌐 Web query:
site:raw.githubusercontent.com monai/losses/unified_focal_loss.py
💡 Result:
I couldn't find a file at raw.githubusercontent.com matching "monai/losses/unified_focal_loss.py". However, the Unified Focal Loss is implemented and available in these locations:
- MONAI repository / loss docs (MONAI provides related loss implementations; see loss docs). [1]
- Original Unified Focal Loss code (author's repo, TensorFlow/Keras). [2]
- Unofficial PyTorch implementations referencing MONAI's implementation. [3]
Sources:
[1] MONAI loss functions documentation / Project-MONAI repository. (docs.monai.io)
[2] Unified Focal Loss paper and code (mlyg/unified-focal-loss). (arxiv.org)
[3] Unofficial PyTorch implementation referencing MONAI. (github.com)
If you want, I can:
- search raw.githubusercontent.com for other possible paths,
- fetch the MONAI source file directly from the Project-MONAI GitHub and show its contents, or
- provide the unified focal loss implementation (PyTorch) copied here. Which would you prefer?
Citations:
- 1: https://docs.monai.io/en/1.5.1/losses.html?utm_source=openai
- 2: https://arxiv.org/abs/2102.04525?utm_source=openai
- 3: https://github.com/JohnMasoner/unified-focal-loss-pytorch?utm_source=openai
🏁 Script executed:
#!/bin/bash
# Read the unified_focal_loss.py file to examine the code
wc -l ./monai/losses/unified_focal_loss.py
Repository: Project-MONAI/MONAI
Length of output: 103
🏁 Script executed:
#!/bin/bash
# Read the AsymmetricFocalLoss class and the relevant sections
sed -n '190,230p' ./monai/losses/unified_focal_loss.py
Repository: Project-MONAI/MONAI
Length of output: 1550
🏁 Script executed:
#!/bin/bash
# Read the AsymmetricUnifiedFocalLoss class, especially around line 296
sed -n '270,310p' ./monai/losses/unified_focal_loss.py
Repository: Project-MONAI/MONAI
Length of output: 1328
🏁 Script executed:
#!/bin/bash
# Find AsymmetricFocalTverskyLoss forward method to understand its output shape
grep -n "class AsymmetricFocalTverskyLoss" ./monai/losses/unified_focal_loss.pyRepository: Project-MONAI/MONAI
Length of output: 106
🏁 Script executed:
#!/bin/bash
# Check the test file for reduction=NONE scenarios
grep -n "reduction.*[Nn]one\|NONE" ./tests/test_unified_focal_loss.py 2>/dev/null || echo "Test file not found or no NONE tests"Repository: Project-MONAI/MONAI
Length of output: 99
🏁 Script executed:
#!/bin/bash
# Read the AsymmetricFocalTverskyLoss forward method to see its output shape
sed -n '23,150p' ./monai/losses/unified_focal_loss.py
Repository: Project-MONAI/MONAI
Length of output: 5311
🏁 Script executed:
#!/bin/bash
# Get a better view of the entire AsymmetricFocalTverskyLoss class
sed -n '23,100p' ./monai/losses/unified_focal_loss.py | head -80
Repository: Project-MONAI/MONAI
Length of output: 3389
🏁 Script executed:
#!/bin/bash
# Check what axis is used in AsymmetricFocalTverskyLoss for reduction
sed -n '85,95p' ./monai/losses/unified_focal_loss.py
Repository: Project-MONAI/MONAI
Length of output: 532
🏁 Script executed:
#!/bin/bash
# Verify the entire test file to confirm no NONE reduction tests
cat ./tests/test_unified_focal_loss.py | grep -i "reduction" | head -20
Repository: Project-MONAI/MONAI
Length of output: 129
🏁 Script executed:
#!/bin/bash
# Find test file location
find . -name "test_unified_focal_loss.py" -type fRepository: Project-MONAI/MONAI
Length of output: 104
🏁 Script executed:
#!/bin/bash
# Look for tests directory and list test files
ls -la ./tests/ | grep unified
Repository: Project-MONAI/MONAI
Length of output: 45
🏁 Script executed:
#!/bin/bash
# Check the test file
cat ./tests/losses/test_unified_focal_loss.py | grep -i "reduction\|NONE" | head -30
Repository: Project-MONAI/MONAI
Length of output: 72
🏁 Script executed:
#!/bin/bash
# Count lines and get overview of tests
wc -l ./tests/losses/test_unified_focal_loss.py
Repository: Project-MONAI/MONAI
Length of output: 108
🏁 Script executed:
#!/bin/bash
# Check what test cases exist
grep -n "def test_\|reduction" ./tests/losses/test_unified_focal_loss.pyRepository: Project-MONAI/MONAI
Length of output: 274
🏁 Script executed:
#!/bin/bash
# Read the entire test file to understand test coverage
cat ./tests/losses/test_unified_focal_loss.py
Repository: Project-MONAI/MONAI
Length of output: 4265
🏁 Script executed:
#!/bin/bash
# Get the exact line numbers for the forward method of AsymmetricUnifiedFocalLoss
sed -n '280,300p' ./monai/losses/unified_focal_loss.py
Repository: Project-MONAI/MONAI
Length of output: 1000
🏁 Script executed:
#!/bin/bash
# Check the default reduction value in __init__ for AsymmetricUnifiedFocalLoss
sed -n '240,280p' ./monai/losses/unified_focal_loss.py
Repository: Project-MONAI/MONAI
Length of output: 1706
Shape mismatch in AsymmetricUnifiedFocalLoss when reduction=NONE
AsymmetricFocalLoss returns shape (B, H, W, [D]) (per-pixel) with reduction=NONE, while AsymmetricFocalTverskyLoss returns shape (B, C) (per-class). Line 296's addition will fail when both use reduction=NONE.
Recommend either:
- Document that `reduction=NONE` is unsupported for AsymmetricUnifiedFocalLoss
- Add a runtime check to prevent NONE reduction in AsymmetricUnifiedFocalLoss
- Align output shapes so both return matching dimensions
Also add test coverage for reduction=NONE scenario.
🤖 Prompt for AI Agents
In monai/losses/unified_focal_loss.py around lines 213 to 226, the final
concatenation/sum assumes compatible per-pixel vs per-class shapes and will
break when constituent losses return reduction=NONE (AsymmetricFocalLoss returns
per-pixel (B,H,W,...), AsymmetricFocalTverskyLoss returns per-class (B,C)), so
add a runtime guard that rejects reduction==LossReduction.NONE (raise a
ValueError with a clear message that AsymmetricUnifiedFocalLoss does not support
reduction=NONE) or alternatively coerce both component losses to use a supported
reduction (SUM/MEAN) before combining; update the docstring to state NONE is
unsupported for this unified loss and add unit tests asserting a ValueError is
raised when reduction=NONE and tests for SUM/MEAN behavior.
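A minimal sketch of such a guard, assuming `LossReduction` is imported from `monai.utils` as elsewhere in this file; the helper name and the exact message wording are illustrative, not the committed fix:

```python
# Sketch: reject the unsupported reduction mode early, before combining losses.
from monai.utils import LossReduction


def _check_reduction(reduction: str) -> None:
    if reduction == LossReduction.NONE.value:
        raise ValueError(
            "AsymmetricUnifiedFocalLoss does not support reduction='none' because its "
            "component losses return per-pixel and per-class tensors of different shapes."
        )
```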
Signed-off-by: ytl0623 <david89062388@gmail.com>
05dac9e to cbed38d (Compare)
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
monai/losses/unified_focal_loss.py (1)
240-240: Remove unused `num_classes` parameter.
The `num_classes` parameter is stored but never used in the implementation. It's not passed to the internal loss instances and doesn't affect behavior. Either use it to validate inputs or remove it from the interface.
♻️ Duplicate comments (3)
monai/losses/unified_focal_loss.py (3)
81-86: Add `stacklevel=2` to warning for proper caller attribution.
The warning call should specify `stacklevel=2` to point to the caller's code rather than this line.
🔎 Proposed fix

  if self.to_onehot_y:
      if n_pred_ch == 1:
-         warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+         warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
      else:

Based on static analysis hint.
188-188: Add `stacklevel=2` to warning for proper caller attribution.
The warning call should specify `stacklevel=2` to point to the caller's code.
🔎 Proposed fix

  if self.to_onehot_y:
      if n_pred_ch == 1:
-         warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+         warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
      else:

Based on static analysis hint.
293-296: Handle or reject `reduction=NONE` to prevent shape mismatch.
When `reduction=NONE`, AsymmetricFocalLoss returns shape (B, H, W, [D]) (per-pixel) while AsymmetricFocalTverskyLoss returns shape (B, C) (per-class). Line 296's addition will fail. Either document and reject NONE reduction with a runtime check, or ensure both losses return compatible shapes.
🔎 Proposed fix - reject NONE reduction

  def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
      """
      Args:
          y_pred: Prediction logits. Shape: (B, C, H, W, [D]). Supports binary (C=1 or C=2) and multi-class (C>2) segmentation.
          y_true: Ground truth labels. Shape should match y_pred (or be indices if to_onehot_y is True).
      """
+     if self.reduction == LossReduction.NONE.value:
+         raise ValueError("AsymmetricUnifiedFocalLoss does not support reduction='none' due to incompatible output shapes from component losses.")
+
      if y_pred.shape != y_true.shape:
🧹 Nitpick comments (2)
monai/losses/unified_focal_loss.py (2)
51-53: Clarify sigmoid behavior for multi-channel inputs.
The docstring states sigmoid is used "for binary/multi-label" when `use_softmax=False`, but doesn't clarify that sigmoid is applied independently to each channel in multi-channel cases. This differs from the binary-only behavior where the background channel is constructed. Consider adding: "For multi-channel inputs, sigmoid is applied per-channel independently (multi-label)."
159-159: Enhance `use_softmax` documentation with usage guidance.
Similar to AsymmetricFocalTverskyLoss, the docstring should clarify when to use softmax (mutually exclusive classes) vs. sigmoid (multi-label/overlapping classes). Consider adding: "Use `True` for mutually exclusive multi-class segmentation, `False` for binary or multi-label scenarios." The difference is illustrated below.
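A small standalone illustration of the distinction the docstring could spell out: softmax couples the channels into one mutually exclusive distribution, while per-channel sigmoid treats each channel as an independent label.

```python
import torch

logits = torch.randn(1, 3, 2, 2)  # (B, C, H, W) with C=3 classes/labels

# Mutually exclusive classes: probabilities across channels sum to 1.
softmax_probs = torch.softmax(logits, dim=1)
print(softmax_probs.sum(dim=1))   # all ones

# Multi-label: each channel is an independent Bernoulli probability.
sigmoid_probs = torch.sigmoid(logits)
print(sigmoid_probs.sum(dim=1))   # generally not 1
```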
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/losses/unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Files:
monai/losses/unified_focal_loss.py
Signed-off-by: ytl0623 <david89062388@gmail.com>
Actionable comments posted: 2
♻️ Duplicate comments (1)
monai/losses/unified_focal_loss.py (1)
175-200: Same redundant probability conversion issue as AsymmetricFocalTverskyLoss.
Lines 175-182 duplicate the auto-handle logic from AsymmetricFocalTverskyLoss with the same flaw: single-channel input always uses sigmoid regardless of the `use_softmax` setting.
Apply the same fix as suggested for AsymmetricFocalTverskyLoss to respect `use_softmax` after expanding to 2 channels.
🧹 Nitpick comments (3)
monai/losses/unified_focal_loss.py (3)
213-226: Reduction logic duplicates the default fallback.
Lines 220-226 explicitly handle MEAN/SUM/NONE, then line 226 returns `torch.mean(total_loss)` as a fallback. This fallback is unreachable if the LossReduction enum is exhaustive.
🔎 Simplification
Remove redundant fallback or add a warning if an unknown reduction is encountered:
  if self.reduction == LossReduction.MEAN.value:
      return torch.mean(total_loss)
  if self.reduction == LossReduction.SUM.value:
      return torch.sum(total_loss)
  if self.reduction == LossReduction.NONE.value:
      return total_loss
- return torch.mean(total_loss)
+ raise ValueError(f"Unsupported reduction: {self.reduction}")

The same applies to AsymmetricFocalTverskyLoss lines 122-129.
288-291: Shape validation allows a mismatch only for binary logits, but one-hot conversion happens downstream.
Lines 288-291 permit a shape mismatch if `is_binary_logits` (C=1 with sigmoid) or if `to_onehot_y=True`. However, the internal losses perform one-hot conversion independently. If `y_true` has a mismatched shape and `to_onehot_y=False`, the internal losses will raise ValueError at their own shape checks (lines 89, 194).
🔎 Simplification
Remove this check and let internal losses handle validation:
- if y_pred.shape != y_true.shape:
-     is_binary_logits = y_pred.shape[1] == 1 and not self.use_softmax
-     if not self.to_onehot_y and not is_binary_logits:
-         raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

Or add a comment explaining why this pre-validation is needed.
89-89: Long exception messages flagged by static analysis.
Lines 89, 194, and 291 embed long f-string messages directly in ValueError. Ruff (TRY003) suggests defining exception classes or message constants for long messages.
For consistency with MONAI conventions, verify if other loss modules use inline messages or constants. If this pattern is acceptable project-wide, ignore the hint.
Also applies to: 194-194, 291-291
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/losses/unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Files:
monai/losses/unified_focal_loss.py
🪛 Ruff (0.14.10)
monai/losses/unified_focal_loss.py
89-89: Avoid specifying long messages outside the exception class
(TRY003)
194-194: Avoid specifying long messages outside the exception class
(TRY003)
291-291: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (2)
monai/losses/unified_focal_loss.py (2)
266-279: Gamma parameters have opposite semantics in AsymmetricFocalLoss vs AsymmetricFocalTverskyLoss.
AsymmetricFocalLoss uses gamma directly: `torch.pow(1 - y_pred, gamma)`, while AsymmetricFocalTverskyLoss uses its reciprocal: `torch.pow(1 - dice_class, 1/gamma)`. Per the paper, the Focal Tversky term's optimal gamma=4/3 enhances the loss (in contrast to Focal loss, which suppresses it). Passing the same gamma=0.5 to both produces mismatched behaviors and may not match the paper's unified formulation intent.
114-115: The focal modulation formula is correct: it properly focuses on hard examples, not the reverse.
With the default `gamma = 0.75`, the exponent `1/gamma = 1.333 > 1` makes hard examples (low Dice values) contribute more to the loss than easy examples, not less. When values are raised to a power greater than 1, small values (easy examples where `1 - dice` is small) shrink more than large values (hard examples where `1 - dice` is large), so easy examples are down-weighted relative to hard ones. This is standard focal behavior and matches the docstring: "focal exponent value to down-weight easy foreground examples." The Unified Focal Loss paper specifies that γ < 1 increases focusing on harder examples, and MONAI's reparameterization with `1/gamma` as the exponent achieves this: gamma = 0.75 yields exponent 1.333, which focuses on hard examples correctly.
Likely an incorrect or invalid review comment.
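A tiny numeric illustration of this point (assuming the default `gamma = 0.75`, i.e. exponent `1/gamma ≈ 1.333`):

```python
import torch

gamma = 0.75
one_minus_dice = torch.tensor([0.1, 0.5, 0.9])   # easy, medium, hard examples

focal_term = torch.pow(one_minus_dice, 1 / gamma)
print(focal_term)                  # approx [0.046, 0.397, 0.869]

# Relative weight of hard vs easy grows: 9x before modulation, roughly 19x after,
# so the loss is dominated by the hard examples.
print(focal_term[2] / focal_term[0])
```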
  # Auto-handle single channel input (binary segmentation case)
  if y_pred.shape[1] == 1 and not self.use_softmax:
      y_pred = torch.sigmoid(y_pred)
      y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
      is_already_prob = True
      if y_true.shape[1] == 1:
          y_true = one_hot(y_true, num_classes=2)
  else:
      is_already_prob = False

  n_pred_ch = y_pred.shape[1]

  if self.to_onehot_y:
      if n_pred_ch == 1:
-         warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+         warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
      else:
-         y_true = one_hot(y_true, num_classes=n_pred_ch)
+         if y_true.shape[1] != n_pred_ch:
+             y_true = one_hot(y_true, num_classes=n_pred_ch)

  if y_true.shape != y_pred.shape:
      raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

- # clip the prediction to avoid NaN
+ # Convert logits to probabilities if not already done
  if not is_already_prob:
      if self.use_softmax:
          y_pred = torch.softmax(y_pred, dim=1)
      else:
          y_pred = torch.sigmoid(y_pred)
Redundant probability conversion when auto-handling single-channel input.
Lines 69-77 apply sigmoid and set is_already_prob=True, then lines 91-96 skip conversion when is_already_prob. However, if use_softmax=True with a single-channel input, the code applies sigmoid at line 71, but softmax would be more appropriate for the 2-channel representation created at line 72.
The logic forces sigmoid for single-channel regardless of use_softmax, which may not align with user intent.
🔎 Suggested approach
Apply the activation chosen by use_softmax after expanding to 2-channel:
- if y_pred.shape[1] == 1 and not self.use_softmax:
+ if y_pred.shape[1] == 1:
      y_pred = torch.sigmoid(y_pred)
      y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
-     is_already_prob = True
+     if self.use_softmax:
+         y_pred = torch.softmax(y_pred, dim=1)
+     is_already_prob = True
      if y_true.shape[1] == 1:
          y_true = one_hot(y_true, num_classes=2)
  else:
      is_already_prob = False

Or document that single-channel always uses sigmoid regardless of use_softmax.
🧰 Tools
🪛 Ruff (0.14.10)
89-89: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In monai/losses/unified_focal_loss.py around lines 69-96, the code currently
always applies sigmoid when auto-handling a single-channel prediction, which is
wrong when use_softmax=True; instead, for single-channel inputs defer
activation: if y_pred has 1 channel and use_softmax is True, expand the logits
to two channels (create a compatible two-channel logits tensor, e.g., by
concatenating the negated and original logits or an appropriate pair) and leave
is_already_prob=False so the later softmax branch runs; if use_softmax is False,
perform the sigmoid, concatenate probabilities to two channels and set
is_already_prob=True; ensure subsequent one-hot conversion and shape checks
remain unchanged.
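A standalone sketch of the alternative described in that prompt (an assumption about one way to do it, not the PR's code): pair the single-channel logits with a zero channel so that a later softmax reproduces the sigmoid probabilities.

```python
import torch

logits = torch.randn(2, 1, 4, 4)                     # single-channel logits

# softmax([0, x]) along the channel dim equals [1 - sigmoid(x), sigmoid(x)],
# so pairing the logits with a zero channel preserves the binary semantics.
two_channel_logits = torch.cat([torch.zeros_like(logits), logits], dim=1)
probs = torch.softmax(two_channel_logits, dim=1)

assert torch.allclose(probs[:, 1:], torch.sigmoid(logits), atol=1e-6)
```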
  # Calculate losses separately for each class
  # Background: Standard Dice Loss
  back_dice = 1 - dice_class[:, 0]
- fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma)
-
- # Average class scores
- loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1))
- return loss
+ # Foreground: Focal Tversky Loss
+ fore_dice = torch.pow(1 - dice_class[:, 1:], 1 / self.gamma)
+
+ # Concatenate background and foreground losses
+ # back_dice needs unsqueeze to match dimensions: (B,) -> (B, 1)
+ all_losses = torch.cat([back_dice.unsqueeze(1), fore_dice], dim=1)
+
+ # Apply reduction
+ if self.reduction == LossReduction.MEAN.value:
+     return torch.mean(all_losses)
+ if self.reduction == LossReduction.SUM.value:
+     return torch.sum(all_losses)
+ if self.reduction == LossReduction.NONE.value:
+     return all_losses
+
+ return torch.mean(all_losses)
Return shape with reduction=NONE is (B, C), not per-pixel.
With reduction=NONE, line 127 returns all_losses with shape (B, C) (per-class). This differs from AsymmetricFocalLoss, which returns per-pixel shape. AsymmetricUnifiedFocalLoss combines both at line 296, causing a shape mismatch when reduction=NONE.
Past review comment flagged this but it remains unresolved.
🔎 Recommended fix
Add runtime guard in AsymmetricUnifiedFocalLoss:
  def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
+     if self.reduction == LossReduction.NONE.value:
+         raise ValueError("AsymmetricUnifiedFocalLoss does not support reduction='none' due to incompatible output shapes from constituent losses.")

Update the docstring to document this limitation and add test coverage.
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In monai/losses/unified_focal_loss.py around lines 110 to 129, with
reduction=NONE the function returns all_losses shaped (B, C) (per-class) which
mismatches the per-pixel shape expected by AsymmetricUnifiedFocalLoss; add a
runtime guard to detect reduction==LossReduction.NONE and either (preferable per
the review) raise a clear ValueError explaining that AsymmetricUnifiedFocalLoss
cannot be used with reduction=NONE because this implementation returns per-class
losses, or alternatively change the reduction=NONE branch to compute and return
per-pixel losses to match AsymmetricFocalLoss; update the function docstring to
document this limitation (or the new behavior) and add unit tests that assert
the ValueError is raised (or validate the corrected per-pixel output shape) to
prevent regressions.
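A sketch of what such a regression test might look like, assuming the guard is added and that the constructor accepts a `reduction` argument as the other MONAI losses do; names are illustrative:

```python
import unittest

import torch

from monai.losses import AsymmetricUnifiedFocalLoss


class TestUnifiedFocalLossReduction(unittest.TestCase):
    def test_none_reduction_rejected(self):
        # Assumes reduction='none' is rejected either at construction or in forward().
        pred = torch.randn(2, 2, 8, 8)
        target = torch.randint(0, 2, (2, 2, 8, 8)).float()
        with self.assertRaises(ValueError):
            loss = AsymmetricUnifiedFocalLoss(reduction="none")
            loss(pred, target)


if __name__ == "__main__":
    unittest.main()
```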
Fixes #8603
Description
Refactors `AsymmetricUnifiedFocalLoss` and its sub-components (`AsymmetricFocalLoss`, `AsymmetricFocalTverskyLoss`) to extend support from binary-only to multi-class segmentation, while also fixing mathematical logic errors and parameter-passing bugs.
Types of changes
- Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- Documentation updated, tested `make html` command in the `docs/` folder.
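For reference, a hedged usage sketch of the interface discussed in this PR (parameter names such as `use_softmax` follow the diffs above and may differ from the final merged API):

```python
import torch

from monai.losses import AsymmetricUnifiedFocalLoss

# Binary case: single-channel logits handled via sigmoid, as discussed above.
binary_loss = AsymmetricUnifiedFocalLoss(use_softmax=False)
binary_pred = torch.randn(2, 1, 32, 32)                      # logits
binary_label = torch.randint(0, 2, (2, 1, 32, 32)).float()
print(binary_loss(binary_pred, binary_label))

# Multi-class case: C-channel logits with softmax and label indices one-hot encoded.
mc_loss = AsymmetricUnifiedFocalLoss(use_softmax=True, to_onehot_y=True)
mc_pred = torch.randn(2, 4, 32, 32)
mc_label = torch.randint(0, 4, (2, 1, 32, 32)).float()
print(mc_loss(mc_pred, mc_label))
```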