WarpingLoss
- class topolosses.losses.warping.WarpingLoss(*args: Any, **kwargs: Any)
Bases: _Loss
A topology-aware loss function that emphasizes structurally critical pixels during segmentation.
- The loss has been defined in:
Hu (2022) Structure-Aware Image Segmentation with Homotopy Warping (NeurIPS).
This loss identifies topologically critical false positives and false negatives through homotopy warping guided by distance transforms. It then applies a cross-entropy loss only at these critical pixels to preserve object connectivity and structure. It is especially suited for applications that require high topological fidelity, such as segmenting thin, connected structures.
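The warping idea can be sketched with NumPy and SciPy. This is a minimal sketch under stated assumptions, not the topolosses implementation. The helper names `is_simple` and `critical_pixels` are hypothetical. The sketch works on 2D binary masks, uses 8-connectivity for the foreground and 4-connectivity for the background, and only warps the target toward the prediction. Mismatched pixels are visited in order of their distance-transform value, nearest to the target boundary first, and this ordering is an assumption. A mismatched pixel is flipped only if it is simple, meaning that flipping it changes neither the number of foreground components nor the number of holes. Mismatches that cannot be flipped are reported as critical.

```python
import numpy as np
from scipy import ndimage

FG_STRUCT = ndimage.generate_binary_structure(2, 2)  # 8-connectivity for foreground
BG_STRUCT = ndimage.generate_binary_structure(2, 1)  # 4-connectivity for background
FOUR_NEIGHBORS = [(0, 1), (1, 0), (1, 2), (2, 1)]  # 4-neighbours of the centre of a 3x3 window


def is_simple(mask: np.ndarray, r: int, c: int) -> bool:
    """True if flipping mask[r, c] preserves topology; (r, c) must not touch the array edge."""
    window = mask[r - 1:r + 2, c - 1:c + 2]
    fg = window.copy()
    fg[1, 1] = False  # the test only looks at the 8 neighbours
    _, n_fg = ndimage.label(fg, structure=FG_STRUCT)
    bg = ~window
    bg[1, 1] = False
    bg_labels, _ = ndimage.label(bg, structure=BG_STRUCT)
    # Count only background components that contain a 4-neighbour of the centre.
    n_bg = len({int(bg_labels[i, j]) for i, j in FOUR_NEIGHBORS} - {0})
    return n_fg == 1 and n_bg == 1


def critical_pixels(pred: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Hypothetical helper: warp `target` toward `pred` through simple pixels, return leftover mismatches."""
    pred_p = np.pad(pred.astype(bool), 1)  # background border so every 3x3 window fits
    warped = np.pad(target.astype(bool), 1)
    # Assumed ordering: distance to the nearest pixel of the opposite class in the target.
    dist = np.where(warped, ndimage.distance_transform_edt(warped),
                    ndimage.distance_transform_edt(~warped))
    changed = True
    while changed:  # each flip removes one mismatch, so this terminates
        changed = False
        diff = warped != pred_p
        for r, c in np.argwhere(diff)[np.argsort(dist[diff], kind="stable")]:
            if is_simple(warped, r, c):
                warped[r, c] = not warped[r, c]  # flip toward the prediction
                changed = True
    return (warped != pred_p)[1:-1, 1:-1]
```

Consider a horizontal target line whose prediction has a one-pixel gap. The gap pixel is critical, because removing it from the target would split the line into two components. An extra predicted pixel sitting on top of the line is different: it is simple and gets warped away, so it is not flagged.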
- Parameters:
eight_connectivity (bool) – If True, foreground components are built with 8-connectivity, so diagonally adjacent pixels belong to the same connected component. If False, 4-connectivity is used when building the component graph. Defaults to True (8-connectivity).
alpha (float) – Weighting factor for combining the base loss and the topology loss (i.e., base_loss + alpha * topology_loss). Defaults to 0.5.
sigmoid (bool) – If True, applies a sigmoid activation to the input before computing the loss. The sigmoid is not applied to the input passed to a custom base loss function. Defaults to False.
softmax (bool) – If True, applies a softmax activation to the input before computing the loss. The softmax is not applied to the input passed to a custom base loss function. Defaults to False.
use_base_loss (bool) – If False, the loss consists only of the topology component, and base_loss and alpha are ignored. Defaults to True.
base_loss (_Loss, optional) – The base loss function to be used alongside the topology loss. Defaults to None, meaning a standard cross-entropy loss will be used.
- Raises:
ValueError – If more than one of [sigmoid, softmax] is set to True.
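The interaction between the documented parameters can be sketched with a hypothetical helper, `combined_loss`. Several details are assumptions and are not taken from the topolosses source:

- The activation is applied only to the input seen by the topology term.
- The base loss, whether default or custom, always receives the raw input.
- The default base loss is `torch.nn.functional.cross_entropy` on logits, with one-hot targets converted to class indices.

The topology term is passed in as a callable so the sketch stays self-contained.

```python
from typing import Callable, Optional

import torch
import torch.nn.functional as F

LossFn = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]


def combined_loss(input: torch.Tensor, target: torch.Tensor, topology_loss: LossFn,
                  alpha: float = 0.5, sigmoid: bool = False, softmax: bool = False,
                  use_base_loss: bool = True, base_loss: Optional[LossFn] = None) -> torch.Tensor:
    """Hypothetical helper mirroring the documented parameters (not the library code)."""
    if sigmoid and softmax:
        raise ValueError("At most one of [sigmoid, softmax] may be set to True.")
    # Assumption: the activation only feeds the topology term.
    activated = input
    if sigmoid:
        activated = torch.sigmoid(input)
    elif softmax:
        activated = torch.softmax(input, dim=1)  # normalise over the class dimension
    topo = topology_loss(activated, target)
    if not use_base_loss:
        return topo  # alpha and base_loss are ignored
    if base_loss is None:
        # Assumed default: cross-entropy on raw logits vs. class indices from one-hot targets.
        base = F.cross_entropy(input, target.argmax(dim=1))
    else:
        base = base_loss(input, target)  # custom base loss sees the raw input
    return base + alpha * topo
```

Keeping the activation out of the base-loss path matches the parameter descriptions above and avoids applying a sigmoid or softmax twice when the base loss already expects logits.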
- compute_warping_loss(input, target)
Compute cross-entropy loss only on pixels critical to preserving segmentation topology.
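Once a boolean mask of critical pixels is available, restricting cross-entropy to those pixels takes only a few lines of PyTorch. The sketch below rests on several assumptions:

- Targets are one-hot with shape (B, C, H, W).
- The input is raw logits.
- The critical mask has shape (B, H, W).
- The result is normalised by the number of critical pixels.

`masked_cross_entropy` is a hypothetical helper, not the signature of `compute_warping_loss`.

```python
import torch
import torch.nn.functional as F


def masked_cross_entropy(logits: torch.Tensor, target: torch.Tensor,
                         critical: torch.Tensor) -> torch.Tensor:
    """Cross-entropy averaged over pixels flagged as topologically critical.

    logits: (B, C, H, W) raw scores; target: (B, C, H, W) one-hot; critical: (B, H, W) bool.
    """
    per_pixel = F.cross_entropy(logits, target.argmax(dim=1), reduction="none")  # (B, H, W)
    mask = critical.to(per_pixel.dtype)
    # Assumed normalisation: mean over critical pixels; clamp avoids 0/0 when none are critical.
    return (per_pixel * mask).sum() / mask.sum().clamp(min=1.0)
```

Non-critical pixels contribute nothing to the gradient. With uniform logits over two classes, each critical pixel contributes log 2.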
- forward(input: torch.Tensor, target: torch.Tensor) → torch.Tensor
Calculates the forward pass of the warping loss.
- Parameters:
input (Tensor) – Input tensor of shape (batch_size, num_classes, H, W).
target (Tensor) – Target tensor of shape (batch_size, num_classes, H, W).
- Returns:
The calculated warping loss.
- Return type:
Tensor
- Raises:
ValueError – If the shape of the ground truth is different from the input shape.
ValueError – If the number of classes is smaller than 2.
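The two documented ValueErrors amount to a simple shape check. The hypothetical `validate_inputs` below mirrors them as an illustration, not as the library's code. Because fewer than two classes is rejected, a single-channel foreground map presumably has to be expanded into explicit background and foreground channels first. The usage lines at the end show this, but it is an inference from the documented error, not a documented recipe.

```python
import torch


def validate_inputs(input: torch.Tensor, target: torch.Tensor) -> None:
    """Hypothetical check mirroring the documented ValueErrors of forward()."""
    if input.shape != target.shape:
        raise ValueError(
            f"Shape mismatch: input {tuple(input.shape)} vs target {tuple(target.shape)}."
        )
    if input.shape[1] < 2:  # dim 1 is num_classes in (batch_size, num_classes, H, W)
        raise ValueError(f"Expected at least 2 classes, got {input.shape[1]}.")


# Assumed usage: expand a single-channel foreground probability map to two channels.
prob = torch.rand(2, 1, 8, 8)
two_channel = torch.cat([1 - prob, prob], dim=1)  # (background, foreground)
validate_inputs(two_channel, two_channel)
```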