256 changes: 164 additions & 92 deletions monai/losses/unified_focal_loss.py
@@ -22,14 +22,13 @@

class AsymmetricFocalTverskyLoss(_Loss):
"""
AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss, which attentions to the foreground class.
AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss that prioritizes the foreground classes.

Actually, it's only supported for binary image segmentation now.
It supports both binary and multi-class segmentation.

Reimplementation of the Asymmetric Focal Tversky Loss described in:

- "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation",
Michael Yeung, Computerized Medical Imaging and Graphics
Michael Yeung, Computerized Medical Imaging and Graphics
"""

def __init__(
@@ -39,119 +38,199 @@ def __init__(
gamma: float = 0.75,
epsilon: float = 1e-7,
reduction: LossReduction | str = LossReduction.MEAN,
use_softmax: bool = False,
) -> None:
"""
Args:
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
delta : weight of the background. Defaults to 0.7.
gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
delta: weight of the background class (used in the Tversky index denominator). Defaults to 0.7.
gamma: focal exponent value to down-weight easy foreground examples. Defaults to 0.75.
epsilon: a small value to prevent division by zero. Defaults to 1e-7.
reduction: {``"none"``, ``"mean"``, ``"sum"``}
Specifies the reduction to apply to the output. Defaults to ``"mean"``.
use_softmax: whether to use softmax to transform original logits into probabilities.
If True, softmax is used (for multi-class). If False, sigmoid is used (for binary/multi-label).
Defaults to False.
"""
super().__init__(reduction=LossReduction(reduction).value)
self.to_onehot_y = to_onehot_y
self.delta = delta
self.gamma = gamma
self.epsilon = epsilon
self.use_softmax = use_softmax

def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
"""
Args:
y_pred: prediction logits or probabilities. Shape should be (B, C, spatial_dims).
y_true: ground truth labels. Shape should match y_pred.
"""

# Auto-handle single channel input (binary segmentation case)
if y_pred.shape[1] == 1 and not self.use_softmax:
y_pred = torch.sigmoid(y_pred)
y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
is_already_prob = True
# Expand y_true to match if it's single channel
if y_true.shape[1] == 1:
y_true = one_hot(y_true, num_classes=2)
else:
is_already_prob = False

n_pred_ch = y_pred.shape[1]

if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
else:
y_true = one_hot(y_true, num_classes=n_pred_ch)
if y_true.shape[1] != n_pred_ch:
y_true = one_hot(y_true, num_classes=n_pred_ch)

if y_true.shape != y_pred.shape:
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

# clip the prediction to avoid NaN
# Convert logits to probabilities if not already done
if not is_already_prob:
if self.use_softmax:
y_pred = torch.softmax(y_pred, dim=1)
else:
y_pred = torch.sigmoid(y_pred)

# Clip the prediction to avoid NaN
y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)

axis = list(range(2, len(y_pred.shape)))

# Calculate true positives (tp), false negatives (fn) and false positives (fp)
tp = torch.sum(y_true * y_pred, dim=axis)
fn = torch.sum(y_true * (1 - y_pred), dim=axis)
fp = torch.sum((1 - y_true) * y_pred, dim=axis)

dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon)

# Calculate losses separately for each class, enhancing both classes
# Calculate losses separately for each class
# Background: Standard Dice Loss
back_dice = 1 - dice_class[:, 0]
fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma)

# Average class scores
loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1))
return loss
# Foreground: Focal Tversky Loss
fore_dice = torch.pow(1 - dice_class[:, 1:], 1 / self.gamma)

# Concatenate background and foreground losses
# back_dice needs unsqueeze to match dimensions: (B,) -> (B, 1)
all_losses = torch.cat([back_dice.unsqueeze(1), fore_dice], dim=1)

# Apply reduction
if self.reduction == LossReduction.MEAN.value:
return torch.mean(all_losses)
if self.reduction == LossReduction.SUM.value:
return torch.sum(all_losses)
if self.reduction == LossReduction.NONE.value:
return all_losses

return torch.mean(all_losses)
Comment on lines +111 to +130

⚠️ Potential issue | 🔴 Critical

Return shape with reduction=NONE is (B, C), not per-pixel.

With reduction=NONE, line 127 returns all_losses with shape (B, C) (per-class). This differs from AsymmetricFocalLoss, which returns a per-pixel tensor. AsymmetricUnifiedFocalLoss combines both at line 296, causing a shape mismatch when reduction=NONE.

A past review comment flagged this, but it remains unresolved.

🔎 Recommended fix

Add runtime guard in AsymmetricUnifiedFocalLoss:

     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
+        if self.reduction == LossReduction.NONE.value:
+            raise ValueError("AsymmetricUnifiedFocalLoss does not support reduction='none' due to incompatible output shapes from constituent losses.")

Update docstring to document this limitation and add test coverage.

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In monai/losses/unified_focal_loss.py around lines 110 to 129, with
reduction=NONE the function returns all_losses shaped (B, C) (per-class) which
mismatches the per-pixel shape expected by AsymmetricUnifiedFocalLoss; add a
runtime guard to detect reduction==LossReduction.NONE and either (preferable per
the review) raise a clear ValueError explaining that AsymmetricUnifiedFocalLoss
cannot be used with reduction=NONE because this implementation returns per-class
losses, or alternatively change the reduction=NONE branch to compute and return
per-pixel losses to match AsymmetricFocalLoss; update the function docstring to
document this limitation (or the new behavior) and add unit tests that assert
the ValueError is raised (or validate the corrected per-pixel output shape) to
prevent regressions.
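
To make the mismatch concrete, here is a minimal repro sketch (illustrative only; it assumes the two losses behave exactly as shown in this diff and is not a committed test):

```python
import torch

from monai.losses.unified_focal_loss import AsymmetricFocalLoss, AsymmetricFocalTverskyLoss

pred = torch.randn(2, 2, 8, 8)                    # (B, C, H, W) logits
grnd = torch.randint(0, 2, (2, 1, 8, 8)).float()  # class indices

tversky = AsymmetricFocalTverskyLoss(to_onehot_y=True, reduction="none")
focal = AsymmetricFocalLoss(to_onehot_y=True, reduction="none")

print(tversky(pred, grnd).shape)  # torch.Size([2, 2])       -> per-class
print(focal(pred, grnd).shape)    # torch.Size([2, 2, 8, 8]) -> per-pixel

# With reduction="none" the two losses report at different granularities,
# which is the inconsistency the guard suggested above is meant to catch.
```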



class AsymmetricFocalLoss(_Loss):
"""
AsymmetricFocalLoss is a variant of FocalTverskyLoss, which attentions to the foreground class.
AsymmetricFocalLoss is a variant of Focal Loss that treats background and foreground differently.

Actually, it's only supported for binary image segmentation now.
It supports both binary and multi-class segmentation.

Reimplementation of the Asymmetric Focal Loss described in:

- "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation",
Michael Yeung, Computerized Medical Imaging and Graphics
Michael Yeung, Computerized Medical Imaging and Graphics
"""

def __init__(
self,
to_onehot_y: bool = False,
delta: float = 0.7,
gamma: float = 2,
gamma: float = 2.0,
epsilon: float = 1e-7,
reduction: LossReduction | str = LossReduction.MEAN,
use_softmax: bool = False,
):
"""
Args:
to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
delta : weight of the background. Defaults to 0.7.
gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
delta: weight for the foreground classes. Defaults to 0.7.
gamma: focusing parameter for the background class (to down-weight easy background examples). Defaults to 2.0.
epsilon: a small value to prevent calculation errors. Defaults to 1e-7.
reduction: {``"none"``, ``"mean"``, ``"sum"``}
use_softmax: whether to use softmax to transform logits. Defaults to False.
"""
super().__init__(reduction=LossReduction(reduction).value)
self.to_onehot_y = to_onehot_y
self.delta = delta
self.gamma = gamma
self.epsilon = epsilon
self.use_softmax = use_softmax

def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
"""
Args:
y_pred: prediction logits or probabilities.
y_true: ground truth labels.
"""

if y_pred.shape[1] == 1 and not self.use_softmax:
y_pred = torch.sigmoid(y_pred)
y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
is_already_prob = True
if y_true.shape[1] == 1:
y_true = one_hot(y_true, num_classes=2)
else:
is_already_prob = False

n_pred_ch = y_pred.shape[1]

if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
else:
y_true = one_hot(y_true, num_classes=n_pred_ch)
if y_true.shape[1] != n_pred_ch:
y_true = one_hot(y_true, num_classes=n_pred_ch)

if y_true.shape != y_pred.shape:
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

if not is_already_prob:
if self.use_softmax:
y_pred = torch.softmax(y_pred, dim=1)
else:
y_pred = torch.sigmoid(y_pred)

y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)

cross_entropy = -y_true * torch.log(y_pred)

# Background (Channel 0): Focal Loss
back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0]
back_ce = (1 - self.delta) * back_ce

fore_ce = cross_entropy[:, 1]
# Foreground (Channels 1+): Standard Cross Entropy
fore_ce = cross_entropy[:, 1:]
fore_ce = self.delta * fore_ce

loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1))
return loss
# Concatenate losses
all_ce = torch.cat([back_ce.unsqueeze(1), fore_ce], dim=1)

# Apply reduction
if self.reduction == LossReduction.MEAN.value:
return torch.mean(torch.sum(all_ce, dim=1))
if self.reduction == LossReduction.SUM.value:
return torch.sum(all_ce)
if self.reduction == LossReduction.NONE.value:
return all_ce

class AsymmetricUnifiedFocalLoss(_Loss):
"""
AsymmetricUnifiedFocalLoss is a variant of Focal Loss.
return torch.mean(torch.sum(all_ce, dim=1))

Actually, it's only supported for binary image segmentation now

Reimplementation of the Asymmetric Unified Focal Tversky Loss described in:
class AsymmetricUnifiedFocalLoss(_Loss):
"""
AsymmetricUnifiedFocalLoss is a wrapper that combines AsymmetricFocalLoss and AsymmetricFocalTverskyLoss.

- "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation",
Michael Yeung, Computerized Medical Imaging and Graphics
This unified loss allows for simultaneously optimizing distribution-based (CE) and region-based (Dice) metrics,
while handling class imbalance through asymmetric weighting.
"""

def __init__(
Expand All @@ -162,79 +241,72 @@ def __init__(
gamma: float = 0.5,
delta: float = 0.7,
reduction: LossReduction | str = LossReduction.MEAN,
use_softmax: bool = False,
):
"""
Args:
to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
num_classes : number of classes, it only supports 2 now. Defaults to 2.
delta : weight of the background. Defaults to 0.7.
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
weight : weight for each loss function, if it's none it's 0.5. Defaults to None.

Example:
>>> import torch
>>> from monai.losses import AsymmetricUnifiedFocalLoss
>>> pred = torch.ones((1,1,32,32), dtype=torch.float32)
>>> grnd = torch.ones((1,1,32,32), dtype=torch.int64)
>>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True)
>>> fl(pred, grnd)
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
num_classes: number of classes. Defaults to 2.
weight: weight factor to balance between Focal Loss and Tversky Loss.
Loss = weight * FocalLoss + (1-weight) * TverskyLoss. Defaults to 0.5.
gamma: focal exponent. Defaults to 0.5.
delta: background/foreground balancing weight. Defaults to 0.7.
reduction: specifies the reduction to apply to the output. Defaults to "mean".
use_softmax: whether to use softmax for probability conversion. Defaults to False.
"""
super().__init__(reduction=LossReduction(reduction).value)
self.to_onehot_y = to_onehot_y
self.num_classes = num_classes
self.gamma = gamma
self.delta = delta
self.weight: float = weight
self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta)
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta)
self.weight = weight
self.use_softmax = use_softmax

self.asy_focal_loss = AsymmetricFocalLoss(
gamma=self.gamma,
delta=self.delta,
use_softmax=self.use_softmax,
to_onehot_y=to_onehot_y,
reduction=LossReduction.NONE,
)
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
gamma=self.gamma,
delta=self.delta,
use_softmax=self.use_softmax,
to_onehot_y=to_onehot_y,
reduction=LossReduction.NONE,
)

# TODO: Implement this function to support multiple classes segmentation
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
"""
Args:
y_pred : the shape should be BNH[WD], where N is the number of classes.
It only supports binary segmentation.
The input should be the original logits since it will be transformed by
a sigmoid in the forward function.
y_true : the shape should be BNH[WD], where N is the number of classes.
It only supports binary segmentation.

Raises:
ValueError: When input and target are different shape
ValueError: When len(y_pred.shape) != 4 and len(y_pred.shape) != 5
ValueError: When num_classes
ValueError: When the number of classes entered does not match the expected number
y_pred: Prediction logits. Shape: (B, C, H, W, [D]).
Supports binary (C=1 or C=2) and multi-class (C>2) segmentation.
y_true: Ground truth labels. Shape should match y_pred (or be indices if to_onehot_y is True).
"""
if y_pred.shape != y_true.shape:
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

if len(y_pred.shape) != 4 and len(y_pred.shape) != 5:
raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}")

if y_pred.shape[1] == 1:
y_pred = one_hot(y_pred, num_classes=self.num_classes)
y_true = one_hot(y_true, num_classes=self.num_classes)

if torch.max(y_true) != self.num_classes - 1:
raise ValueError(f"Please make sure the number of classes is {self.num_classes-1}")

n_pred_ch = y_pred.shape[1]
if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
else:
y_true = one_hot(y_true, num_classes=n_pred_ch)
is_binary_logits = y_pred.shape[1] == 1 and not self.use_softmax
if not self.to_onehot_y and not is_binary_logits:
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
Comment on lines 287 to +290

🛠️ Refactor suggestion | 🟠 Major

Shape validation logic needs clarification.

Line 288 defines is_binary_logits, but it is used only at line 289. The condition permits a shape mismatch for binary logits or when to_onehot_y=True, but the rule is unclear. Document when a shape mismatch is expected and when it is an error.

🔎 Add clarifying comment
         if y_pred.shape != y_true.shape:
+            # Allow mismatch when: (1) binary logits will be auto-expanded, or (2) y_true will be one-hot encoded
             is_binary_logits = y_pred.shape[1] == 1 and not self.use_softmax
             if not self.to_onehot_y and not is_binary_logits:
                 raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

Based on static analysis hint.

🧰 Tools
🪛 Ruff (0.14.10)

290-290: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In monai/losses/unified_focal_loss.py around lines 287-290, the current
shape-check treats a mismatch as OK for either to_onehot_y=True or for "binary
logits" (is_binary_logits) but has no explanation or clear rule for allowed
shapes; add a concise clarifying comment above this block that states the exact
allowed cases (e.g., when to_onehot_y=True we expect class-indexed y_true shapes
different from y_pred, and for binary logits y_pred has a channel dim of 1 while
y_true may omit that channel so a shape mismatch is acceptable), and tighten the
condition by explicitly checking the common binary case (y_true ndim equals
y_pred ndim - 1) before allowing the mismatch so behaviour is unambiguous.


asy_focal_loss = self.asy_focal_loss(y_pred, y_true)
asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true)

loss: torch.Tensor = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss
# Align Focal Loss to (B, C) by averaging over spatial dimensions
spatial_dims = list(range(2, len(asy_focal_loss.shape)))
focal_aligned = torch.mean(asy_focal_loss, dim=spatial_dims)

if self.reduction == LossReduction.SUM.value:
return torch.sum(loss) # sum over the batch and channel dims
if self.reduction == LossReduction.NONE.value:
return loss # returns [N, num_classes] losses
# Calculate weighted sum. Result shape: (B, C)
combined_loss = self.weight * focal_aligned + (1 - self.weight) * asy_focal_tversky_loss

loss: torch.Tensor
if self.reduction == LossReduction.MEAN.value:
return torch.mean(loss)
raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
loss = torch.mean(combined_loss)
elif self.reduction == LossReduction.SUM.value:
loss = torch.sum(combined_loss)
elif self.reduction == LossReduction.NONE.value:
loss = combined_loss
else:
loss = torch.mean(combined_loss)

return loss
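
The docstring's usage example was dropped in this diff; a hedged replacement under the updated signature could look like the following (defaults as shown above; not committed documentation):

```python
import torch

from monai.losses import AsymmetricUnifiedFocalLoss

pred = torch.randn(1, 2, 32, 32)                    # two-channel logits
grnd = torch.randint(0, 2, (1, 1, 32, 32)).float()  # class indices
loss_fn = AsymmetricUnifiedFocalLoss(to_onehot_y=True, use_softmax=False)
print(loss_fn(pred, grnd))  # scalar under the default "mean" reduction
```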