From 45d9877b4eb4d149841455468e1e5595b6362910 Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Tue, 23 Dec 2025 13:09:53 +0800 Subject: [PATCH 1/4] Add sigmoid/softmax interface for AsymmetricUnifiedFocalLoss Signed-off-by: ytl0623 --- monai/losses/unified_focal_loss.py | 240 +++++++++++++++--------- tests/losses/test_unified_focal_loss.py | 105 ++++++++--- 2 files changed, 224 insertions(+), 121 deletions(-) diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py index 8484eb67ed..b7afc5ec45 100644 --- a/monai/losses/unified_focal_loss.py +++ b/monai/losses/unified_focal_loss.py @@ -22,14 +22,13 @@ class AsymmetricFocalTverskyLoss(_Loss): """ - AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss, which attentions to the foreground class. + AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss that prioritizes the foreground classes. - Actually, it's only supported for binary image segmentation now. + It supports both binary and multi-class segmentation. Reimplementation of the Asymmetric Focal Tversky Loss described in: - - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation", - Michael Yeung, Computerized Medical Imaging and Graphics + Michael Yeung, Computerized Medical Imaging and Graphics """ def __init__( @@ -39,119 +38,200 @@ def __init__( gamma: float = 0.75, epsilon: float = 1e-7, reduction: LossReduction | str = LossReduction.MEAN, + use_softmax: bool = False, ) -> None: """ Args: to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. - delta : weight of the background. Defaults to 0.7. - gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75. - epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7. + delta: weight of the background class (used in the Tversky index denominator). Defaults to 0.7. + gamma: focal exponent value to down-weight easy foreground examples. Defaults to 0.75. + epsilon: a small value to prevent division by zero. Defaults to 1e-7. + reduction: {``"none"``, ``"mean"``, ``"sum"``} + Specifies the reduction to apply to the output. Defaults to ``"mean"``. + use_softmax: whether to use softmax to transform original logits into probabilities. + If True, softmax is used (for multi-class). If False, sigmoid is used (for binary/multi-label). + Defaults to False. """ super().__init__(reduction=LossReduction(reduction).value) self.to_onehot_y = to_onehot_y self.delta = delta self.gamma = gamma self.epsilon = epsilon + self.use_softmax = use_softmax def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: + """ + Args: + y_pred: prediction logits or probabilities. Shape should be (B, C, spatial_dims). + y_true: ground truth labels. Shape should match y_pred. + """ + + # Auto-handle single channel input (binary segmentation case) + if y_pred.shape[1] == 1 and not self.use_softmax: + y_pred = torch.sigmoid(y_pred) + y_pred = torch.cat([1 - y_pred, y_pred], dim=1) + is_already_prob = True + if y_true.shape[1] == 1: + y_true = one_hot(y_true, num_classes=2) + else: + is_already_prob = False + n_pred_ch = y_pred.shape[1] if self.to_onehot_y: if n_pred_ch == 1: warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") else: - y_true = one_hot(y_true, num_classes=n_pred_ch) + if y_true.shape[1] != n_pred_ch: + y_true = one_hot(y_true, num_classes=n_pred_ch) if y_true.shape != y_pred.shape: raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") - # clip the prediction to avoid NaN + # Convert logits to probabilities if not already done + if not is_already_prob: + if self.use_softmax: + y_pred = torch.softmax(y_pred, dim=1) + else: + y_pred = torch.sigmoid(y_pred) + + # Clip the prediction to avoid NaN y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) + axis = list(range(2, len(y_pred.shape))) # Calculate true positives (tp), false negatives (fn) and false positives (fp) tp = torch.sum(y_true * y_pred, dim=axis) fn = torch.sum(y_true * (1 - y_pred), dim=axis) fp = torch.sum((1 - y_true) * y_pred, dim=axis) + dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon) - # Calculate losses separately for each class, enhancing both classes + # Calculate losses separately for each class + # Background: Standard Dice Loss back_dice = 1 - dice_class[:, 0] - fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma) - # Average class scores - loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1)) - return loss + # Foreground: Focal Tversky Loss + fore_dice = torch.pow(1 - dice_class[:, 1:], 1 / self.gamma) + + # Concatenate background and foreground losses + # back_dice needs unsqueeze to match dimensions: (B,) -> (B, 1) + all_losses = torch.cat([back_dice.unsqueeze(1), fore_dice], dim=1) + + # Apply reduction + if self.reduction == LossReduction.MEAN.value: + return torch.mean(all_losses) + if self.reduction == LossReduction.SUM.value: + return torch.sum(all_losses) + if self.reduction == LossReduction.NONE.value: + return all_losses + + return torch.mean(all_losses) class AsymmetricFocalLoss(_Loss): """ - AsymmetricFocalLoss is a variant of FocalTverskyLoss, which attentions to the foreground class. + AsymmetricFocalLoss is a variant of Focal Loss that treats background and foreground differently. - Actually, it's only supported for binary image segmentation now. + It supports both binary and multi-class segmentation. Reimplementation of the Asymmetric Focal Loss described in: - - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation", - Michael Yeung, Computerized Medical Imaging and Graphics + Michael Yeung, Computerized Medical Imaging and Graphics """ def __init__( self, to_onehot_y: bool = False, delta: float = 0.7, - gamma: float = 2, + gamma: float = 2.0, epsilon: float = 1e-7, reduction: LossReduction | str = LossReduction.MEAN, + use_softmax: bool = False, ): """ Args: - to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False. - delta : weight of the background. Defaults to 0.7. - gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75. - epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7. + to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. + delta: weight for the foreground classes. Defaults to 0.7. + gamma: focusing parameter for the background class (to down-weight easy background examples). Defaults to 2.0. + epsilon: a small value to prevent calculation errors. Defaults to 1e-7. + reduction: {``"none"``, ``"mean"``, ``"sum"``} + use_softmax: whether to use softmax to transform logits. Defaults to False. """ super().__init__(reduction=LossReduction(reduction).value) self.to_onehot_y = to_onehot_y self.delta = delta self.gamma = gamma self.epsilon = epsilon + self.use_softmax = use_softmax def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: + """ + Args: + y_pred: prediction logits or probabilities. + y_true: ground truth labels. + """ + + if y_pred.shape[1] == 1 and not self.use_softmax: + y_pred = torch.sigmoid(y_pred) + y_pred = torch.cat([1 - y_pred, y_pred], dim=1) + is_already_prob = True + if y_true.shape[1] == 1: + y_true = one_hot(y_true, num_classes=2) + else: + is_already_prob = False + n_pred_ch = y_pred.shape[1] if self.to_onehot_y: if n_pred_ch == 1: warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") else: - y_true = one_hot(y_true, num_classes=n_pred_ch) + if y_true.shape[1] != n_pred_ch: + y_true = one_hot(y_true, num_classes=n_pred_ch) if y_true.shape != y_pred.shape: raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") + if not is_already_prob: + if self.use_softmax: + y_pred = torch.softmax(y_pred, dim=1) + else: + y_pred = torch.sigmoid(y_pred) + y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) + cross_entropy = -y_true * torch.log(y_pred) + # Background (Channel 0): Focal Loss back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0] back_ce = (1 - self.delta) * back_ce - - fore_ce = cross_entropy[:, 1] + # Foreground (Channels 1+): Standard Cross Entropy + fore_ce = cross_entropy[:, 1:] fore_ce = self.delta * fore_ce - loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1)) - return loss + # Concatenate losses + all_ce = torch.cat([back_ce.unsqueeze(1), fore_ce], dim=1) + # Sum over classes (dim=1) to get total loss per pixel + total_loss = torch.sum(all_ce, dim=1) -class AsymmetricUnifiedFocalLoss(_Loss): - """ - AsymmetricUnifiedFocalLoss is a variant of Focal Loss. + # Apply reduction + if self.reduction == LossReduction.MEAN.value: + return torch.mean(total_loss) + if self.reduction == LossReduction.SUM.value: + return torch.sum(total_loss) + if self.reduction == LossReduction.NONE.value: + return total_loss + return torch.mean(total_loss) - Actually, it's only supported for binary image segmentation now - Reimplementation of the Asymmetric Unified Focal Tversky Loss described in: +class AsymmetricUnifiedFocalLoss(_Loss): + """ + AsymmetricUnifiedFocalLoss is a wrapper that combines AsymmetricFocalLoss and AsymmetricFocalTverskyLoss. - - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation", - Michael Yeung, Computerized Medical Imaging and Graphics + This unified loss allows for simultaneously optimizing distribution-based (CE) and region-based (Dice) metrics, + while handling class imbalance through asymmetric weighting. """ def __init__( @@ -162,79 +242,57 @@ def __init__( gamma: float = 0.5, delta: float = 0.7, reduction: LossReduction | str = LossReduction.MEAN, + use_softmax: bool = False, ): """ Args: - to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False. - num_classes : number of classes, it only supports 2 now. Defaults to 2. - delta : weight of the background. Defaults to 0.7. - gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75. - epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7. - weight : weight for each loss function, if it's none it's 0.5. Defaults to None. - - Example: - >>> import torch - >>> from monai.losses import AsymmetricUnifiedFocalLoss - >>> pred = torch.ones((1,1,32,32), dtype=torch.float32) - >>> grnd = torch.ones((1,1,32,32), dtype=torch.int64) - >>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True) - >>> fl(pred, grnd) + to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. + num_classes: number of classes. Defaults to 2. + weight: weight factor to balance between Focal Loss and Tversky Loss. + Loss = weight * FocalLoss + (1-weight) * TverskyLoss. Defaults to 0.5. + gamma: focal exponent. Defaults to 0.5. + delta: background/foreground balancing weight. Defaults to 0.7. + reduction: specifies the reduction to apply to the output. Defaults to "mean". + use_softmax: whether to use softmax for probability conversion. Defaults to False. """ super().__init__(reduction=LossReduction(reduction).value) self.to_onehot_y = to_onehot_y self.num_classes = num_classes self.gamma = gamma self.delta = delta - self.weight: float = weight - self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta) - self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta) + self.weight = weight + self.use_softmax = use_softmax + + self.asy_focal_loss = AsymmetricFocalLoss( + gamma=self.gamma, + delta=self.delta, + use_softmax=self.use_softmax, + to_onehot_y=to_onehot_y, + reduction=reduction, + ) + self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss( + gamma=self.gamma, + delta=self.delta, + use_softmax=self.use_softmax, + to_onehot_y=to_onehot_y, + reduction=reduction, + ) - # TODO: Implement this function to support multiple classes segmentation def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: """ Args: - y_pred : the shape should be BNH[WD], where N is the number of classes. - It only supports binary segmentation. - The input should be the original logits since it will be transformed by - a sigmoid in the forward function. - y_true : the shape should be BNH[WD], where N is the number of classes. - It only supports binary segmentation. - - Raises: - ValueError: When input and target are different shape - ValueError: When len(y_pred.shape) != 4 and len(y_pred.shape) != 5 - ValueError: When num_classes - ValueError: When the number of classes entered does not match the expected number + y_pred: Prediction logits. Shape: (B, C, H, W, [D]). + Supports binary (C=1 or C=2) and multi-class (C>2) segmentation. + y_true: Ground truth labels. Shape should match y_pred (or be indices if to_onehot_y is True). """ if y_pred.shape != y_true.shape: - raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") - - if len(y_pred.shape) != 4 and len(y_pred.shape) != 5: - raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}") - - if y_pred.shape[1] == 1: - y_pred = one_hot(y_pred, num_classes=self.num_classes) - y_true = one_hot(y_true, num_classes=self.num_classes) - - if torch.max(y_true) != self.num_classes - 1: - raise ValueError(f"Please make sure the number of classes is {self.num_classes-1}") - - n_pred_ch = y_pred.shape[1] - if self.to_onehot_y: - if n_pred_ch == 1: - warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") - else: - y_true = one_hot(y_true, num_classes=n_pred_ch) + is_binary_logits = y_pred.shape[1] == 1 and not self.use_softmax + if not self.to_onehot_y and not is_binary_logits: + raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") asy_focal_loss = self.asy_focal_loss(y_pred, y_true) asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true) - loss: torch.Tensor = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss + loss = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss - if self.reduction == LossReduction.SUM.value: - return torch.sum(loss) # sum over the batch and channel dims - if self.reduction == LossReduction.NONE.value: - return loss # returns [N, num_classes] losses - if self.reduction == LossReduction.MEAN.value: - return torch.mean(loss) - raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') + return loss diff --git a/tests/losses/test_unified_focal_loss.py b/tests/losses/test_unified_focal_loss.py index 3b868a560e..ccf96496b6 100644 --- a/tests/losses/test_unified_focal_loss.py +++ b/tests/losses/test_unified_focal_loss.py @@ -19,47 +19,92 @@ from monai.losses import AsymmetricUnifiedFocalLoss -TEST_CASES = [ - [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) - { - "y_pred": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]), - "y_true": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]), - }, - 0.0, - ], - [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) - { - "y_pred": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]), - "y_true": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]), - }, - 0.0, - ], +# 1. Binary Case (Logits input): Prediction matches GT perfectly +# Input Shape: (B, 1, H, W) -> Auto expanded internally +TEST_CASE_BINARY_LOGITS = [ + {"y_pred": torch.tensor([[[[10.0, -10.0], [-10.0, 10.0]]]]), "y_true": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])}, + 0.0, + {"use_softmax": False, "to_onehot_y": False}, +] + +# 2. Binary Case (2 Channels input): Prediction matches GT perfectly +# Input Shape: (B, 2, H, W) +TEST_CASE_BINARY_2CH = [ + { + "y_pred": torch.tensor( + [[[[-10.0, 10.0], [10.0, -10.0]], [[10.0, -10.0], [-10.0, 10.0]]]] # Ch0 (Background): Low, High, High, Low + ), # Ch1 (Foreground): High, Low, Low, High + "y_true": torch.tensor([[[[1, 0], [0, 1]]]]), + }, + 0.0, + {"use_softmax": True, "to_onehot_y": True}, +] + +# 3. Multi-Class Case (3 Channels): Prediction matches GT perfectly +TEST_CASE_MULTICLASS_PERFECT = [ + { + "y_pred": torch.tensor( + [ + [ + [[10.0, -10.0], [-10.0, 10.0]], # Class 0 Logits + [[-10.0, 10.0], [-10.0, -10.0]], # Class 1 Logits + [[-10.0, -10.0], [10.0, -10.0]], + ] + ] + ), # Class 2 Logits + "y_true": torch.tensor([[[[0, 1], [2, 0]]]]), # Indices + }, + 0.0, + {"use_softmax": True, "to_onehot_y": True}, +] + +# 4. Multi-Class Case: Wrong Prediction +TEST_CASE_MULTICLASS_WRONG = [ + { + "y_pred": torch.tensor( + [[[[-10.0, -10.0], [-10.0, -10.0]], [[10.0, 10.0], [10.0, 10.0]], [[-10.0, -10.0], [-10.0, -10.0]]]] + ), + "y_true": torch.tensor([[[[0, 0], [0, 0]]]]), # GT is class 0, but Pred is class 1 + }, + None, + {"use_softmax": True, "to_onehot_y": True}, ] class TestAsymmetricUnifiedFocalLoss(unittest.TestCase): - @parameterized.expand(TEST_CASES) - def test_result(self, input_data, expected_val): - loss = AsymmetricUnifiedFocalLoss() - result = loss(**input_data) - np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4) + @parameterized.expand([TEST_CASE_BINARY_LOGITS, TEST_CASE_BINARY_2CH, TEST_CASE_MULTICLASS_PERFECT]) + def test_perfect_prediction(self, input_data, expected_val, args): + loss_func = AsymmetricUnifiedFocalLoss(**args) + result = loss_func(**input_data) + # We use a small tolerance because 10.0 logits is not exactly probability 1.0 + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-3, rtol=1e-3) + + @parameterized.expand([TEST_CASE_MULTICLASS_WRONG]) + def test_wrong_prediction(self, input_data, expected_val, args): + loss_func = AsymmetricUnifiedFocalLoss(**args) + result = loss_func(**input_data) + self.assertGreater(result.item(), 1.0, "Loss should be high for wrong predictions") def test_ill_shape(self): loss = AsymmetricUnifiedFocalLoss() - with self.assertRaisesRegex(ValueError, ""): - loss(torch.ones((2, 2, 2)), torch.ones((2, 2, 2, 2))) + with self.assertRaisesRegex(ValueError, "ground truth has different shape"): + loss(torch.ones((1, 1, 4, 4)), torch.ones((1, 1, 2, 2))) def test_with_cuda(self): - loss = AsymmetricUnifiedFocalLoss() - i = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]) - j = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]) - if torch.cuda.is_available(): - i = i.cuda() - j = j.cuda() + if not torch.cuda.is_available(): + print("CUDA not available, skipping test_with_cuda") + return + + loss = AsymmetricUnifiedFocalLoss(use_softmax=False, to_onehot_y=False) + # Binary logits case on GPU + i = torch.tensor([[[[10.0, 0], [0, 10.0]]], [[[10.0, 0], [0, 10.0]]]]).cuda() + j = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]).cuda() + output = loss(i, j) - print(output) - np.testing.assert_allclose(output.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4) + print(f"CUDA Output: {output.item()}") + self.assertTrue(output.is_cuda) + self.assertLess(output.item(), 1.0) if __name__ == "__main__": From cbed38d7740b99aa304116f7b869c041a79e8149 Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Tue, 23 Dec 2025 13:23:24 +0800 Subject: [PATCH 2/4] fix: Returning Any from function declared to return Tensor Signed-off-by: ytl0623 --- monai/losses/unified_focal_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py index b7afc5ec45..a3b99ef932 100644 --- a/monai/losses/unified_focal_loss.py +++ b/monai/losses/unified_focal_loss.py @@ -293,6 +293,6 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: asy_focal_loss = self.asy_focal_loss(y_pred, y_true) asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true) - loss = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss + loss: torch.Tensor = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss return loss From b08de6563ab83c80116486761c7873014f646fe2 Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Tue, 23 Dec 2025 15:50:36 +0800 Subject: [PATCH 3/4] add stacklevel=2 to warning Signed-off-by: ytl0623 --- monai/losses/unified_focal_loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py index a3b99ef932..0e0d164739 100644 --- a/monai/losses/unified_focal_loss.py +++ b/monai/losses/unified_focal_loss.py @@ -80,7 +80,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: if self.to_onehot_y: if n_pred_ch == 1: - warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") + warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2) else: if y_true.shape[1] != n_pred_ch: y_true = one_hot(y_true, num_classes=n_pred_ch) @@ -185,7 +185,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: if self.to_onehot_y: if n_pred_ch == 1: - warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") + warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2) else: if y_true.shape[1] != n_pred_ch: y_true = one_hot(y_true, num_classes=n_pred_ch) From 7a100d92927ad96568febfc45c7e05017b84d712 Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Tue, 23 Dec 2025 16:19:07 +0800 Subject: [PATCH 4/4] fix: reduction=NONE Signed-off-by: ytl0623 --- monai/losses/unified_focal_loss.py | 34 +++++++++++++++++++++--------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py index 0e0d164739..d7694f0592 100644 --- a/monai/losses/unified_focal_loss.py +++ b/monai/losses/unified_focal_loss.py @@ -71,6 +71,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: y_pred = torch.sigmoid(y_pred) y_pred = torch.cat([1 - y_pred, y_pred], dim=1) is_already_prob = True + # Expand y_true to match if it's single channel if y_true.shape[1] == 1: y_true = one_hot(y_true, num_classes=2) else: @@ -213,17 +214,15 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: # Concatenate losses all_ce = torch.cat([back_ce.unsqueeze(1), fore_ce], dim=1) - # Sum over classes (dim=1) to get total loss per pixel - total_loss = torch.sum(all_ce, dim=1) - # Apply reduction if self.reduction == LossReduction.MEAN.value: - return torch.mean(total_loss) + return torch.mean(torch.sum(all_ce, dim=1)) if self.reduction == LossReduction.SUM.value: - return torch.sum(total_loss) + return torch.sum(all_ce) if self.reduction == LossReduction.NONE.value: - return total_loss - return torch.mean(total_loss) + return all_ce + + return torch.mean(torch.sum(all_ce, dim=1)) class AsymmetricUnifiedFocalLoss(_Loss): @@ -268,14 +267,14 @@ def __init__( delta=self.delta, use_softmax=self.use_softmax, to_onehot_y=to_onehot_y, - reduction=reduction, + reduction=LossReduction.NONE, ) self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss( gamma=self.gamma, delta=self.delta, use_softmax=self.use_softmax, to_onehot_y=to_onehot_y, - reduction=reduction, + reduction=LossReduction.NONE, ) def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: @@ -293,6 +292,21 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: asy_focal_loss = self.asy_focal_loss(y_pred, y_true) asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true) - loss: torch.Tensor = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss + # Align Focal Loss to (B, C) by averaging over spatial dimensions + spatial_dims = list(range(2, len(asy_focal_loss.shape))) + focal_aligned = torch.mean(asy_focal_loss, dim=spatial_dims) + + # Calculate weighted sum. Result shape: (B, C) + combined_loss = self.weight * focal_aligned + (1 - self.weight) * asy_focal_tversky_loss + + loss: torch.Tensor + if self.reduction == LossReduction.MEAN.value: + loss = torch.mean(combined_loss) + elif self.reduction == LossReduction.SUM.value: + loss = torch.sum(combined_loss) + elif self.reduction == LossReduction.NONE.value: + loss = combined_loss + else: + loss = torch.mean(combined_loss) return loss