BettiMatchingLoss

class topolosses.losses.betti_matching.BettiMatchingLoss(*args: Any, **kwargs: Any)

Bases: _Loss

BettiMatchingLoss is a topology-aware loss function that encourages spatially and feature-wise accurate preservation of topology in image segmentation tasks.

The loss function is based on Betti matching, a concept from persistent homology that enables a spatially correct matching of topological features via induced matchings of persistence barcodes.

The method has been introduced and refined in the following works:
  • Stucki et al. (2023) “Topologically Faithful Image Segmentation via Induced Matching of Persistence Barcodes”
  • Stucki et al. (2024) “Efficient Betti Matching Enables Topology-Aware 3D Segmentation via Persistent Homology”
  • Berger et al. (2024) “Topologically Faithful Multi-class Segmentation in Medical Images”

By default, the Betti matching component is combined with a Dice loss component. For more flexibility, it can be combined with other base loss functions or used as a standalone topology-aware loss.
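As a reminder of what the default base component measures, here is a generic soft Dice loss in plain Python. The function name and the `smooth` constant are hypothetical, and the exact Dice variant used by topolosses (smoothing, reduction over classes and batch, background handling) is not specified here and may differ.

```python
def soft_dice_loss(pred, target, smooth=1e-5):
    """Generic soft Dice loss over flat lists of probabilities and {0, 1} labels.

    Illustrative only (hypothetical helper): the default Dice component in
    topolosses may use different smoothing, reduction and background handling.
    """
    # Soft overlap between prediction and ground truth.
    intersection = sum(p * t for p, t in zip(pred, target))
    # Total "mass" of both masks; smooth avoids division by zero for empty masks.
    denominator = sum(pred) + sum(target)
    return 1.0 - (2.0 * intersection + smooth) / (denominator + smooth)


print(soft_dice_loss([0.9, 0.1, 0.8], [1.0, 0.0, 1.0]))
```

A perfect prediction gives a loss close to 0 and a completely disjoint one gives a loss close to 1. Dice rewards pixel overlap but is blind to topology (a single broken pixel can split a structure at almost no cost), which is the gap the Betti matching term is meant to close.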

Parameters:
  • filtration_type (str) – Determines how the filtration is computed:
      - superlevel: features appear as input values decrease
      - sublevel: features appear as input values increase
      - bothlevels: applies both filtration types and combines the results

  • num_processes (int) – Number of parallel processes for computing Betti matching

  • push_unmatched_to_1_0 (bool) – If True, pushes unmatched birth points toward 1 and death points toward 0. If False, simply pushes birth and death points together.

  • barcode_length_threshold (float) – Minimum persistence (birth-death) threshold to filter out short-lived topological features that may be noise.

  • topology_weights (tuple[float, float]) – Tuple of weights (matched, unmatched) controlling the importance of:
      - matched features between prediction and target
      - unmatched features in the prediction

  • sphere (bool) – If True, adds padding to create periodic boundary conditions (sphere topology). Defaults to False.

  • include_background (bool) – If True, includes the background class in the topology-aware computation. Background inclusion in the base loss component should be controlled independently. Defaults to False.

  • alpha (float) – Weighting factor for the topology-aware loss component. Only applied if a base loss is used. Defaults to 0.5.

  • sigmoid (bool) – If True, applies a sigmoid activation to the forward-pass input before computing the topology-aware component. The default Dice loss also receives the sigmoid-transformed input, whereas custom base losses receive the raw input. Typically used for binary segmentation. Defaults to False.

  • softmax (bool) – If True, applies softmax to the forward-pass input before computing the topology-aware component. The default Dice loss also receives the softmax-transformed input, whereas custom base losses receive the raw input. Defaults to False.

  • use_base_loss (bool) – If False, the loss consists only of the topology-aware component, and a forward call returns that component in full. In this case, base_loss, weights, and alpha are ignored.

  • base_loss (_Loss, optional) – The base loss function to be used alongside the topology-aware loss. Defaults to None, meaning a Dice component with default parameters will be used.
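To make filtration_type and barcode_length_threshold more concrete, the sketch below computes 0-dimensional persistence pairs (connected components) of a 1D signal with a union-find and the elder rule, obtains the superlevel pairs by negating the signal, and then drops short bars. It is a toy illustration under simplifying assumptions (1D input, dimension-0 features only, bars with persistence strictly below the threshold removed). It is not the topolosses implementation, which operates on image tensors, and all function names are hypothetical. The bothlevels option is not modelled.

```python
def sublevel_pairs_0d(values):
    """(birth, death) pairs of connected components under the sublevel filtration.

    Points enter in increasing order of value. A component is born at a local
    minimum and, by the elder rule, the younger of two merging components dies.
    The oldest component never dies and is paired with +inf.
    """
    n = len(values)
    order = sorted(range(n), key=lambda i: values[i])
    parent = [None] * n  # None means "not yet in the filtration"
    birth = {}           # root index -> birth value of its component

    def find(i):
        while parent[i] != i:
            parent[i] = parent[parent[i]]  # path halving
            i = parent[i]
        return i

    pairs = []
    for i in order:
        parent[i] = i
        birth[i] = values[i]
        for j in (i - 1, i + 1):
            if 0 <= j < n and parent[j] is not None:
                ri, rj = find(i), find(j)
                if ri == rj:
                    continue
                # Elder rule: the component born later dies at the current value.
                young, old = (ri, rj) if birth[ri] > birth[rj] else (rj, ri)
                if birth[young] < values[i]:  # skip zero-length bars
                    pairs.append((birth[young], values[i]))
                parent[young] = old
    pairs.append((birth[find(order[0])], float("inf")))
    return pairs


def superlevel_pairs_0d(values):
    # Superlevel filtration of f equals the sublevel filtration of -f; flip signs back.
    return [(-b, -d) for b, d in sublevel_pairs_0d([-v for v in values])]


def filter_short_bars(pairs, threshold):
    # Keep bars whose persistence |birth - death| reaches the threshold.
    return [(b, d) for b, d in pairs if abs(b - d) >= threshold]


signal = [0.0, 2.0, 1.0, 3.0, 0.5]
print(sublevel_pairs_0d(signal))    # [(1.0, 2.0), (0.5, 3.0), (0.0, inf)]
print(superlevel_pairs_0d(signal))  # [(2.0, 1.0), (3.0, -inf)]
print(filter_short_bars(superlevel_pairs_0d(signal), 1.5))  # [(3.0, -inf)]
```

In the superlevel case births are larger than deaths, which matches how foreground components appear in a probability map: they are born at high confidence and merge as the threshold is lowered.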

Raises:
  • ValueError – If more than one of [sigmoid, softmax] is set to True.

  • ValueError – If topology_weights is not a list of length 2.
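The roles of topology_weights and push_unmatched_to_1_0 can be illustrated with a toy penalty over persistence pairs. The sketch assumes the matching has already been computed elsewhere and that each term is a squared Euclidean distance in the (birth, death) plane, in the spirit of Stucki et al. (2023). The function name and the exact form of each term are hypothetical, and topolosses may scale or combine the terms differently. The sketch also mirrors the documented ValueError for a topology_weights of the wrong length.

```python
def betti_matching_penalty(matched, unmatched_pred, topology_weights=(1.0, 1.0),
                           push_unmatched_to_1_0=False):
    """Toy weighted penalty over persistence pairs (hypothetical; not the topolosses code).

    matched:        list of ((b_pred, d_pred), (b_target, d_target)) tuples.
    unmatched_pred: list of (b, d) pairs in the prediction without a partner.
    """
    if len(topology_weights) != 2:
        raise ValueError("topology_weights must have length 2: (matched, unmatched)")
    w_matched, w_unmatched = topology_weights

    # Matched features: squared distance from each predicted pair to its target pair.
    matched_term = sum((bp - bt) ** 2 + (dp - dt) ** 2
                       for (bp, dp), (bt, dt) in matched)

    unmatched_term = 0.0
    for b, d in unmatched_pred:
        if push_unmatched_to_1_0:
            # Target point (1, 0): birth pushed toward 1, death toward 0.
            unmatched_term += (b - 1.0) ** 2 + d ** 2
        else:
            # Target point on the diagonal: birth and death pushed together.
            mid = (b + d) / 2.0
            unmatched_term += (b - mid) ** 2 + (d - mid) ** 2

    return w_matched * matched_term + w_unmatched * unmatched_term


matched = [((0.9, 0.2), (1.0, 0.0))]  # one predicted bar matched to a target bar
unmatched = [(0.6, 0.4)]              # one predicted bar without a partner
print(round(betti_matching_penalty(matched, unmatched), 6))  # 0.07
print(round(betti_matching_penalty(matched, unmatched,
                                   push_unmatched_to_1_0=True), 6))  # 0.37
```

In this sketch, raising the second entry of topology_weights makes unmatched predicted features more costly relative to imperfectly placed matched ones.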

compute_betti_matching_loss(input: torch.Tensor, target: torch.Tensor) → tuple[torch.Tensor, list[torch.Tensor]]

Compute the Betti matching loss for batched input and target tensors.

Processes input and target tensors through the appropriate filtration transformations, computes matching between persistence barcodes, and aggregates the loss values across all instances in the batch.

forward(input: torch.Tensor, target: torch.Tensor) → torch.Tensor

Calculates the forward pass of the Betti matching loss.

Parameters:
  • input (Tensor) – Input tensor of shape (batch_size, num_classes, H, W).

  • target (Tensor) – Target tensor of shape (batch_size, num_classes, H, W).

Returns:

The calculated Betti matching loss.

Return type:

Tensor

Raises:
  • ValueError – If the shape of the ground truth is different from the input shape.

  • ValueError – If softmax=True and the number of channels for the prediction is 1.
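The documented input checks and optional activations can be sketched as follows. This is a plain-Python illustration on a single pixel's vector of per-class logits, with hypothetical function names; in PyTorch the activations would typically be torch.sigmoid(input) and torch.softmax(input, dim=1) applied to the (batch_size, num_classes, H, W) tensor. Per the parameter documentation, the sigmoid/softmax mutual-exclusion check is raised at construction time; it is folded in here only to keep the sketch self-contained.

```python
import math


def check_forward_inputs(input_shape, target_shape, softmax=False):
    """Mirror the two ValueErrors documented for forward() (hypothetical helper)."""
    if tuple(target_shape) != tuple(input_shape):
        raise ValueError(
            f"target shape {tuple(target_shape)} differs from input shape {tuple(input_shape)}"
        )
    # The channel dimension is index 1 in (batch_size, num_classes, H, W).
    if softmax and input_shape[1] == 1:
        raise ValueError("softmax=True requires more than one channel")


def _sigmoid(x):
    # Numerically stable logistic function (avoids overflow in math.exp).
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def activate_pixel(logits, sigmoid=False, softmax=False):
    """Apply the optional activation to one pixel's per-class logits (hypothetical helper)."""
    # Documented as a constructor-time check; repeated here for a self-contained sketch.
    if sigmoid and softmax:
        raise ValueError("at most one of sigmoid and softmax may be True")
    if sigmoid:
        # Element-wise: each channel is squashed to (0, 1) independently.
        return [_sigmoid(x) for x in logits]
    if softmax:
        # Across channels: subtract the maximum for numerical stability.
        m = max(logits)
        exps = [math.exp(x - m) for x in logits]
        total = sum(exps)
        return [e / total for e in exps]
    return list(logits)


check_forward_inputs((2, 3, 64, 64), (2, 3, 64, 64), softmax=True)  # passes
print(activate_pixel([0.0, 0.0], softmax=True))  # [0.5, 0.5]
```

The single-channel restriction follows from the softmax definition: normalizing one channel against itself always yields 1, so the output would carry no information.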