HutopoLoss
- class topolosses.losses.hutopo.HutopoLoss(*args: Any, **kwargs: Any)
Bases: _Loss
Topology-preserving segmentation loss combining pixel-wise and topological objectives.
- This loss was introduced in:
Hu et al. (2019) “Topology-Preserving Deep Image Segmentation” (NeurIPS).
This loss penalizes discrepancies between the persistence diagrams of the predicted and ground-truth segmentations, measured with a Wasserstein distance on birth–death pairs. It can be used on its own or combined with a base segmentation loss through a weighting factor α.
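To make "persistence diagram of a probability map" concrete, the framework-free sketch below computes 0-dimensional persistence (connected components) of a 1D signal under a sublevel-set filtration, using union-find and the elder rule. The helper `sublevel_persistence_1d` is hypothetical and not part of topolosses; the loss itself operates on 2D images and may also use higher-dimensional features, which this sketch ignores.

```python
from typing import List, Tuple

def sublevel_persistence_1d(values: List[float]) -> List[Tuple[float, float]]:
    """0-dimensional sublevel-set persistence of a 1D signal (hypothetical helper).

    Returns (birth, death) pairs of connected components; the component
    born at the global minimum never dies and gets death = inf.
    Zero-persistence pairs are dropped.
    """
    n = len(values)
    parent = list(range(n))
    birth = list(values)   # a root's value is the birth of its component
    added = [False] * n    # which samples have entered the filtration

    def find(i: int) -> int:
        # Union-find root lookup with path halving.
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    pairs: List[Tuple[float, float]] = []
    # Sweep the threshold upward: samples enter in increasing order of value.
    for i in sorted(range(n), key=lambda k: (values[k], k)):
        added[i] = True
        for j in (i - 1, i + 1):
            if 0 <= j < n and added[j]:
                ri, rj = find(i), find(j)
                if ri == rj:
                    continue
                # Elder rule: the component born later dies at the current value.
                young, old = (ri, rj) if birth[ri] >= birth[rj] else (rj, ri)
                if values[i] > birth[young]:
                    pairs.append((birth[young], values[i]))
                parent[young] = old
    if n:
        # The oldest component survives the whole filtration.
        pairs.append((birth[find(0)], float("inf")))
    return pairs
```

For p = [0.1, 0.9, 0.3, 0.8, 0.2], the sublevel diagram of p records the three low-probability basins, while the diagram of 1 − p records the two high-probability peaks (0.9 and 0.8), i.e. the two foreground blobs. Applying a sublevel filtration to 1 − p is how a superlevel-set diagram of p is obtained.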
- Parameters:
filtration_type (FiltrationType or string) –
- Choose how the topological filtration is built on the probability maps:
sublevel: sublevel-set filtration on the raw output p.
superlevel: sublevel-set filtration on the inverted output 1 − p (equivalent to a superlevel-set filtration on p).
bothlevels: both SUBLEVEL and SUPERLEVEL, with their results concatenated.
Defaults to SUPERLEVEL.
num_processes (int) – Number of parallel processes for persistent homology computations. Higher values may improve throughput. Defaults to 1.
include_background (bool) – Whether to include the background channel when computing topological loss. If False, only foreground classes are used. Defaults to False.
alpha (float) – Weight of the topological term relative to the base segmentation loss: total loss = base_loss + alpha * topo_loss. Defaults to 0.5.
softmax (bool) – If True, applies softmax to network outputs before loss computation. Mutually exclusive with sigmoid. Defaults to False.
sigmoid (bool) – If True, applies sigmoid activation to network outputs before loss computation. Mutually exclusive with softmax. Defaults to False.
use_base_loss (bool) – Whether to include a pixel-wise base loss component. If False, only the topological term is used and alpha is ignored. Defaults to True.
base_loss (Optional[_Loss]) – Custom base loss function (e.g., DiceLoss, CrossEntropy). If None and use_base_loss=True, a default Dice loss is used. Defaults to None.
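A framework-free sketch of how these options could fit together for a single pixel's class scores: optional softmax or sigmoid activation, optional removal of the background channel, and the documented combination total = base_loss + alpha * topo_loss. All helper names are hypothetical, and treating channel 0 as the background channel is an assumption not stated in this reference.

```python
import math
from typing import List

def activate(logits: List[float], softmax: bool = False, sigmoid: bool = False) -> List[float]:
    """Apply the optional activation to one pixel's per-class scores (hypothetical)."""
    if softmax:
        m = max(logits)  # subtract the max for numerical stability
        exps = [math.exp(x - m) for x in logits]
        total = sum(exps)
        return [e / total for e in exps]
    if sigmoid:
        return [1.0 / (1.0 + math.exp(-x)) for x in logits]
    return list(logits)  # neither flag set: scores are used as-is

def select_channels(scores: List[float], include_background: bool = False) -> List[float]:
    """Keep all channels, or drop channel 0 (assumed to be background)."""
    return list(scores) if include_background else list(scores[1:])

def total_loss(base: float, topo: float, alpha: float = 0.5, use_base_loss: bool = True) -> float:
    """total = base + alpha * topo; without a base loss, alpha is ignored."""
    return base + alpha * topo if use_base_loss else topo
```

With the defaults, a base loss of 0.2 and a topological loss of 0.4 give 0.2 + 0.5 * 0.4 = 0.4; with use_base_loss=False the result is just the topological term, 0.4.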
- Raises:
ValueError – If both sigmoid and softmax are set to True simultaneously.
ValueError – If filtration_type is provided as a string but does not match any of the valid options (‘SUPERLEVEL’, ‘SUBLEVEL’, ‘BOTHLEVELS’).
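The two documented constructor errors can be mirrored with a small validation sketch. `FiltrationType` below is a stand-in enum, not topolosses' own class; its member values and the case-insensitive matching of strings are assumptions made for illustration.

```python
from enum import Enum
from typing import Union

class FiltrationType(Enum):
    # Stand-in for topolosses' FiltrationType; member values are assumptions.
    SUPERLEVEL = "superlevel"
    SUBLEVEL = "sublevel"
    BOTHLEVELS = "bothlevels"

def parse_filtration(value: Union[FiltrationType, str]) -> FiltrationType:
    """Accept an enum member or its name (matched case-insensitively here)."""
    if isinstance(value, FiltrationType):
        return value
    try:
        return FiltrationType[value.upper()]
    except KeyError:
        valid = ", ".join(f"'{m.name}'" for m in FiltrationType)
        raise ValueError(f"Invalid filtration_type {value!r}; expected one of {valid}") from None

def check_activation_flags(softmax: bool, sigmoid: bool) -> None:
    """Reject the mutually exclusive combination softmax=True, sigmoid=True."""
    if softmax and sigmoid:
        raise ValueError("softmax and sigmoid cannot both be True")
```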
- compute_hutopo_loss(prediction: torch.Tensor, target: torch.Tensor) → List[torch.Tensor]
Compute the HuTopo loss: the topological discrepancy between prediction and target, obtained by matching their persistence diagrams under a squared-L2 Wasserstein distance on birth–death pairs.
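For intuition, the following brute-force sketch computes an exact squared-L2 Wasserstein matching between two small diagrams of finite (birth, death) pairs, where any point may be matched to the diagonal instead of to a point in the other diagram. It is a generic textbook formulation, not topolosses' implementation: `squared_wasserstein` is hypothetical, and the library's matching strategy and its handling of infinite pairs may differ.

```python
import itertools
from typing import List, Tuple

Pair = Tuple[float, float]

def _diag_cost(p: Pair) -> float:
    """Squared L2 distance from (birth, death) to its projection onto the diagonal."""
    b, d = p
    return (d - b) ** 2 / 2.0

def squared_wasserstein(pred: List[Pair], target: List[Pair]) -> float:
    """Exact squared-L2 matching cost by brute force (tiny diagrams only)."""
    n, m = len(pred), len(target)
    size = n + m
    # Rows: pred points, then one diagonal slot per target point.
    # Columns: target points, then one diagonal slot per pred point.
    cost = [[0.0] * size for _ in range(size)]
    for i in range(size):
        for j in range(size):
            if i < n and j < m:
                (b1, d1), (b2, d2) = pred[i], target[j]
                cost[i][j] = (b1 - b2) ** 2 + (d1 - d2) ** 2
            elif i < n:
                cost[i][j] = _diag_cost(pred[i])    # pred point sent to the diagonal
            elif j < m:
                cost[i][j] = _diag_cost(target[j])  # target point sent to the diagonal
            # diagonal-to-diagonal entries stay 0.0
    return min(
        sum(cost[i][perm[i]] for i in range(size))
        for perm in itertools.permutations(range(size))
    )
```

Matching pred = [(0.1, 0.9), (0.4, 0.5)] against target = [(0.0, 1.0)] pairs the persistent point with the target (cost 0.02) and pushes the short-lived one to the diagonal (cost 0.005), for a total of 0.025. The factorial search limits this to toy inputs; practical implementations solve the same assignment problem with, e.g., the Hungarian algorithm.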
- forward(input: torch.Tensor, target: torch.Tensor) → torch.Tensor
Calculates the forward pass of the HutopoLoss.
- Parameters:
input (Tensor) – Input tensor of shape (batch_size, num_classes, H, W).
target (Tensor) – Target tensor of shape (batch_size, num_classes, H, W).
- Returns:
The calculated HuTopo loss.
- Return type:
Tensor
- Raises:
ValueError – If the shape of the ground truth is different from the input shape.
ValueError – If softmax=True and the number of channels for the prediction is 1.
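The two checks documented for forward can be sketched on plain shape tuples in (batch_size, num_classes, H, W) order. `check_forward_inputs` is a hypothetical helper; the library performs these checks on tensors, and its error messages may differ.

```python
from typing import Sequence

def check_forward_inputs(input_shape: Sequence[int], target_shape: Sequence[int], softmax: bool = False) -> None:
    """Raise ValueError for the two failure cases documented for forward() (hypothetical)."""
    if tuple(input_shape) != tuple(target_shape):
        raise ValueError(
            f"ground truth has shape {tuple(target_shape)}, "
            f"but input has shape {tuple(input_shape)}"
        )
    # A softmax over a single channel always outputs 1.0, so it carries no information.
    if softmax and input_shape[1] == 1:
        raise ValueError("softmax=True requires more than one channel")
```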