-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Add sigmoid/softmax interface for AsymmetricUnifiedFocalLoss #8669
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,14 +22,13 @@ | |
|
|
||
| class AsymmetricFocalTverskyLoss(_Loss): | ||
| """ | ||
| AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss, which attentions to the foreground class. | ||
| AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss that prioritizes the foreground classes. | ||
|
|
||
| Actually, it's only supported for binary image segmentation now. | ||
| It supports both binary and multi-class segmentation. | ||
|
|
||
| Reimplementation of the Asymmetric Focal Tversky Loss described in: | ||
|
|
||
| - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation", | ||
| Michael Yeung, Computerized Medical Imaging and Graphics | ||
| Michael Yeung, Computerized Medical Imaging and Graphics | ||
| """ | ||
|
|
||
| def __init__( | ||
|
|
@@ -39,119 +38,199 @@ def __init__( | |
| gamma: float = 0.75, | ||
| epsilon: float = 1e-7, | ||
| reduction: LossReduction | str = LossReduction.MEAN, | ||
| use_softmax: bool = False, | ||
| ) -> None: | ||
| """ | ||
| Args: | ||
| to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. | ||
| delta : weight of the background. Defaults to 0.7. | ||
| gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75. | ||
| epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7. | ||
| delta: weight of the background class (used in the Tversky index denominator). Defaults to 0.7. | ||
| gamma: focal exponent value to down-weight easy foreground examples. Defaults to 0.75. | ||
| epsilon: a small value to prevent division by zero. Defaults to 1e-7. | ||
| reduction: {``"none"``, ``"mean"``, ``"sum"``} | ||
| Specifies the reduction to apply to the output. Defaults to ``"mean"``. | ||
| use_softmax: whether to use softmax to transform original logits into probabilities. | ||
| If True, softmax is used (for multi-class). If False, sigmoid is used (for binary/multi-label). | ||
| Defaults to False. | ||
| """ | ||
| super().__init__(reduction=LossReduction(reduction).value) | ||
| self.to_onehot_y = to_onehot_y | ||
| self.delta = delta | ||
| self.gamma = gamma | ||
| self.epsilon = epsilon | ||
| self.use_softmax = use_softmax | ||
|
|
||
| def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: | ||
| """ | ||
| Args: | ||
| y_pred: prediction logits or probabilities. Shape should be (B, C, spatial_dims). | ||
| y_true: ground truth labels. Shape should match y_pred. | ||
| """ | ||
|
|
||
| # Auto-handle single channel input (binary segmentation case) | ||
| if y_pred.shape[1] == 1 and not self.use_softmax: | ||
| y_pred = torch.sigmoid(y_pred) | ||
| y_pred = torch.cat([1 - y_pred, y_pred], dim=1) | ||
| is_already_prob = True | ||
| # Expand y_true to match if it's single channel | ||
| if y_true.shape[1] == 1: | ||
| y_true = one_hot(y_true, num_classes=2) | ||
| else: | ||
| is_already_prob = False | ||
|
|
||
| n_pred_ch = y_pred.shape[1] | ||
|
|
||
| if self.to_onehot_y: | ||
| if n_pred_ch == 1: | ||
| warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") | ||
| warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2) | ||
| else: | ||
| y_true = one_hot(y_true, num_classes=n_pred_ch) | ||
| if y_true.shape[1] != n_pred_ch: | ||
| y_true = one_hot(y_true, num_classes=n_pred_ch) | ||
|
|
||
| if y_true.shape != y_pred.shape: | ||
| raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") | ||
|
|
||
| # clip the prediction to avoid NaN | ||
| # Convert logits to probabilities if not already done | ||
| if not is_already_prob: | ||
| if self.use_softmax: | ||
| y_pred = torch.softmax(y_pred, dim=1) | ||
| else: | ||
| y_pred = torch.sigmoid(y_pred) | ||
|
|
||
| # Clip the prediction to avoid NaN | ||
| y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) | ||
|
|
||
| axis = list(range(2, len(y_pred.shape))) | ||
|
|
||
| # Calculate true positives (tp), false negatives (fn) and false positives (fp) | ||
| tp = torch.sum(y_true * y_pred, dim=axis) | ||
| fn = torch.sum(y_true * (1 - y_pred), dim=axis) | ||
| fp = torch.sum((1 - y_true) * y_pred, dim=axis) | ||
|
|
||
| dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon) | ||
|
|
||
| # Calculate losses separately for each class, enhancing both classes | ||
| # Calculate losses separately for each class | ||
| # Background: Standard Dice Loss | ||
| back_dice = 1 - dice_class[:, 0] | ||
| fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma) | ||
|
|
||
| # Average class scores | ||
| loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1)) | ||
| return loss | ||
| # Foreground: Focal Tversky Loss | ||
| fore_dice = torch.pow(1 - dice_class[:, 1:], 1 / self.gamma) | ||
|
|
||
| # Concatenate background and foreground losses | ||
| # back_dice needs unsqueeze to match dimensions: (B,) -> (B, 1) | ||
| all_losses = torch.cat([back_dice.unsqueeze(1), fore_dice], dim=1) | ||
|
|
||
| # Apply reduction | ||
| if self.reduction == LossReduction.MEAN.value: | ||
| return torch.mean(all_losses) | ||
| if self.reduction == LossReduction.SUM.value: | ||
| return torch.sum(all_losses) | ||
| if self.reduction == LossReduction.NONE.value: | ||
| return all_losses | ||
|
|
||
| return torch.mean(all_losses) | ||
|
|
||
|
|
||
| class AsymmetricFocalLoss(_Loss): | ||
| """ | ||
| AsymmetricFocalLoss is a variant of FocalTverskyLoss, which attentions to the foreground class. | ||
| AsymmetricFocalLoss is a variant of Focal Loss that treats background and foreground differently. | ||
|
|
||
| Actually, it's only supported for binary image segmentation now. | ||
| It supports both binary and multi-class segmentation. | ||
|
|
||
| Reimplementation of the Asymmetric Focal Loss described in: | ||
|
|
||
| - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation", | ||
| Michael Yeung, Computerized Medical Imaging and Graphics | ||
| Michael Yeung, Computerized Medical Imaging and Graphics | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| to_onehot_y: bool = False, | ||
| delta: float = 0.7, | ||
| gamma: float = 2, | ||
| gamma: float = 2.0, | ||
| epsilon: float = 1e-7, | ||
| reduction: LossReduction | str = LossReduction.MEAN, | ||
| use_softmax: bool = False, | ||
| ): | ||
| """ | ||
| Args: | ||
| to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False. | ||
| delta : weight of the background. Defaults to 0.7. | ||
| gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75. | ||
| epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7. | ||
| to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. | ||
| delta: weight for the foreground classes. Defaults to 0.7. | ||
| gamma: focusing parameter for the background class (to down-weight easy background examples). Defaults to 2.0. | ||
| epsilon: a small value to prevent calculation errors. Defaults to 1e-7. | ||
| reduction: {``"none"``, ``"mean"``, ``"sum"``} | ||
| use_softmax: whether to use softmax to transform logits. Defaults to False. | ||
| """ | ||
| super().__init__(reduction=LossReduction(reduction).value) | ||
| self.to_onehot_y = to_onehot_y | ||
| self.delta = delta | ||
| self.gamma = gamma | ||
| self.epsilon = epsilon | ||
| self.use_softmax = use_softmax | ||
|
|
||
| def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: | ||
| """ | ||
| Args: | ||
| y_pred: prediction logits or probabilities. | ||
| y_true: ground truth labels. | ||
| """ | ||
|
|
||
| if y_pred.shape[1] == 1 and not self.use_softmax: | ||
| y_pred = torch.sigmoid(y_pred) | ||
| y_pred = torch.cat([1 - y_pred, y_pred], dim=1) | ||
| is_already_prob = True | ||
| if y_true.shape[1] == 1: | ||
| y_true = one_hot(y_true, num_classes=2) | ||
| else: | ||
| is_already_prob = False | ||
|
|
||
| n_pred_ch = y_pred.shape[1] | ||
|
|
||
| if self.to_onehot_y: | ||
| if n_pred_ch == 1: | ||
| warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") | ||
| warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2) | ||
| else: | ||
| y_true = one_hot(y_true, num_classes=n_pred_ch) | ||
| if y_true.shape[1] != n_pred_ch: | ||
| y_true = one_hot(y_true, num_classes=n_pred_ch) | ||
|
|
||
| if y_true.shape != y_pred.shape: | ||
| raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") | ||
|
|
||
| if not is_already_prob: | ||
| if self.use_softmax: | ||
| y_pred = torch.softmax(y_pred, dim=1) | ||
| else: | ||
| y_pred = torch.sigmoid(y_pred) | ||
|
|
||
| y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) | ||
|
|
||
| cross_entropy = -y_true * torch.log(y_pred) | ||
|
|
||
| # Background (Channel 0): Focal Loss | ||
| back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0] | ||
| back_ce = (1 - self.delta) * back_ce | ||
|
|
||
| fore_ce = cross_entropy[:, 1] | ||
| # Foreground (Channels 1+): Standard Cross Entropy | ||
| fore_ce = cross_entropy[:, 1:] | ||
| fore_ce = self.delta * fore_ce | ||
|
|
||
| loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1)) | ||
| return loss | ||
| # Concatenate losses | ||
| all_ce = torch.cat([back_ce.unsqueeze(1), fore_ce], dim=1) | ||
|
|
||
| # Apply reduction | ||
| if self.reduction == LossReduction.MEAN.value: | ||
| return torch.mean(torch.sum(all_ce, dim=1)) | ||
| if self.reduction == LossReduction.SUM.value: | ||
| return torch.sum(all_ce) | ||
| if self.reduction == LossReduction.NONE.value: | ||
| return all_ce | ||
|
|
||
| class AsymmetricUnifiedFocalLoss(_Loss): | ||
| """ | ||
| AsymmetricUnifiedFocalLoss is a variant of Focal Loss. | ||
| return torch.mean(torch.sum(all_ce, dim=1)) | ||
|
|
||
| Actually, it's only supported for binary image segmentation now | ||
|
|
||
| Reimplementation of the Asymmetric Unified Focal Tversky Loss described in: | ||
| class AsymmetricUnifiedFocalLoss(_Loss): | ||
| """ | ||
| AsymmetricUnifiedFocalLoss is a wrapper that combines AsymmetricFocalLoss and AsymmetricFocalTverskyLoss. | ||
|
|
||
| - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation", | ||
| Michael Yeung, Computerized Medical Imaging and Graphics | ||
| This unified loss allows for simultaneously optimizing distribution-based (CE) and region-based (Dice) metrics, | ||
| while handling class imbalance through asymmetric weighting. | ||
| """ | ||
|
|
||
| def __init__( | ||
|
|
@@ -162,79 +241,72 @@ def __init__( | |
| gamma: float = 0.5, | ||
| delta: float = 0.7, | ||
| reduction: LossReduction | str = LossReduction.MEAN, | ||
| use_softmax: bool = False, | ||
| ): | ||
| """ | ||
| Args: | ||
| to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False. | ||
| num_classes : number of classes, it only supports 2 now. Defaults to 2. | ||
| delta : weight of the background. Defaults to 0.7. | ||
| gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75. | ||
| epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7. | ||
| weight : weight for each loss function, if it's none it's 0.5. Defaults to None. | ||
|
|
||
| Example: | ||
| >>> import torch | ||
| >>> from monai.losses import AsymmetricUnifiedFocalLoss | ||
| >>> pred = torch.ones((1,1,32,32), dtype=torch.float32) | ||
| >>> grnd = torch.ones((1,1,32,32), dtype=torch.int64) | ||
| >>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True) | ||
| >>> fl(pred, grnd) | ||
| to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. | ||
| num_classes: number of classes. Defaults to 2. | ||
| weight: weight factor to balance between Focal Loss and Tversky Loss. | ||
| Loss = weight * FocalLoss + (1-weight) * TverskyLoss. Defaults to 0.5. | ||
| gamma: focal exponent. Defaults to 0.5. | ||
| delta: background/foreground balancing weight. Defaults to 0.7. | ||
| reduction: specifies the reduction to apply to the output. Defaults to "mean". | ||
| use_softmax: whether to use softmax for probability conversion. Defaults to False. | ||
| """ | ||
| super().__init__(reduction=LossReduction(reduction).value) | ||
| self.to_onehot_y = to_onehot_y | ||
| self.num_classes = num_classes | ||
| self.gamma = gamma | ||
| self.delta = delta | ||
| self.weight: float = weight | ||
| self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta) | ||
| self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta) | ||
| self.weight = weight | ||
| self.use_softmax = use_softmax | ||
|
|
||
| self.asy_focal_loss = AsymmetricFocalLoss( | ||
| gamma=self.gamma, | ||
| delta=self.delta, | ||
| use_softmax=self.use_softmax, | ||
| to_onehot_y=to_onehot_y, | ||
| reduction=LossReduction.NONE, | ||
| ) | ||
| self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss( | ||
| gamma=self.gamma, | ||
| delta=self.delta, | ||
| use_softmax=self.use_softmax, | ||
| to_onehot_y=to_onehot_y, | ||
| reduction=LossReduction.NONE, | ||
| ) | ||
|
|
||
| # TODO: Implement this function to support multiple classes segmentation | ||
| def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: | ||
| """ | ||
| Args: | ||
| y_pred : the shape should be BNH[WD], where N is the number of classes. | ||
| It only supports binary segmentation. | ||
| The input should be the original logits since it will be transformed by | ||
| a sigmoid in the forward function. | ||
| y_true : the shape should be BNH[WD], where N is the number of classes. | ||
| It only supports binary segmentation. | ||
|
|
||
| Raises: | ||
| ValueError: When input and target are different shape | ||
| ValueError: When len(y_pred.shape) != 4 and len(y_pred.shape) != 5 | ||
| ValueError: When num_classes | ||
| ValueError: When the number of classes entered does not match the expected number | ||
| y_pred: Prediction logits. Shape: (B, C, H, W, [D]). | ||
| Supports binary (C=1 or C=2) and multi-class (C>2) segmentation. | ||
| y_true: Ground truth labels. Shape should match y_pred (or be indices if to_onehot_y is True). | ||
| """ | ||
| if y_pred.shape != y_true.shape: | ||
| raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") | ||
|
|
||
| if len(y_pred.shape) != 4 and len(y_pred.shape) != 5: | ||
| raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}") | ||
|
|
||
| if y_pred.shape[1] == 1: | ||
| y_pred = one_hot(y_pred, num_classes=self.num_classes) | ||
| y_true = one_hot(y_true, num_classes=self.num_classes) | ||
|
|
||
| if torch.max(y_true) != self.num_classes - 1: | ||
| raise ValueError(f"Please make sure the number of classes is {self.num_classes-1}") | ||
|
|
||
| n_pred_ch = y_pred.shape[1] | ||
| if self.to_onehot_y: | ||
| if n_pred_ch == 1: | ||
| warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") | ||
| else: | ||
| y_true = one_hot(y_true, num_classes=n_pred_ch) | ||
| is_binary_logits = y_pred.shape[1] == 1 and not self.use_softmax | ||
| if not self.to_onehot_y and not is_binary_logits: | ||
| raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") | ||
|
Comment on lines 287 to +290
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

**Shape validation logic needs clarification.** Line 288 defines `is_binary_logits`.

🔎 Add clarifying comment

```diff
 if y_pred.shape != y_true.shape:
+    # Allow mismatch when: (1) binary logits will be auto-expanded, or (2) y_true will be one-hot encoded
     is_binary_logits = y_pred.shape[1] == 1 and not self.use_softmax
     if not self.to_onehot_y and not is_binary_logits:
         raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
```

Based on static analysis hint.

🧰 Tools
🪛 Ruff (0.14.10)
290-290: Avoid specifying long messages outside the exception class (TRY003)

🤖 Prompt for AI Agents |
||
|
|
||
| asy_focal_loss = self.asy_focal_loss(y_pred, y_true) | ||
| asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true) | ||
|
|
||
| loss: torch.Tensor = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss | ||
| # Align Focal Loss to (B, C) by averaging over spatial dimensions | ||
| spatial_dims = list(range(2, len(asy_focal_loss.shape))) | ||
| focal_aligned = torch.mean(asy_focal_loss, dim=spatial_dims) | ||
|
|
||
| if self.reduction == LossReduction.SUM.value: | ||
| return torch.sum(loss) # sum over the batch and channel dims | ||
| if self.reduction == LossReduction.NONE.value: | ||
| return loss # returns [N, num_classes] losses | ||
| # Calculate weighted sum. Result shape: (B, C) | ||
| combined_loss = self.weight * focal_aligned + (1 - self.weight) * asy_focal_tversky_loss | ||
|
|
||
| loss: torch.Tensor | ||
| if self.reduction == LossReduction.MEAN.value: | ||
| return torch.mean(loss) | ||
| raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') | ||
| loss = torch.mean(combined_loss) | ||
| elif self.reduction == LossReduction.SUM.value: | ||
| loss = torch.sum(combined_loss) | ||
| elif self.reduction == LossReduction.NONE.value: | ||
| loss = combined_loss | ||
| else: | ||
| loss = torch.mean(combined_loss) | ||
|
|
||
| return loss | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Return shape with reduction=NONE is (B, C), not per-pixel.
With `reduction=NONE`, line 127 returns `all_losses` with shape `(B, C)` (per-class). This differs from `AsymmetricFocalLoss`, which returns per-pixel shape. `AsymmetricUnifiedFocalLoss` combines both at line 296, causing a shape mismatch when `reduction=NONE`.

Past review comment flagged this but it remains unresolved.
🔎 Recommended fix
Add runtime guard in AsymmetricUnifiedFocalLoss:
Update docstring to document this limitation and add test coverage.
🤖 Prompt for AI Agents