WarpingLoss

class topolosses.losses.warping.WarpingLoss(*args: Any, **kwargs: Any)

Bases: _Loss

A topology-aware loss function that emphasizes structurally critical pixels during segmentation.

The loss has been defined in:

Hu (2022) Structure-Aware Image Segmentation with Homotopy Warping (NeurIPS).

This loss uses distance transforms to identify topologically sensitive false positives and false negatives, then applies a cross-entropy loss only at these critical pixels to preserve object connectivity and structure. It is especially suited to applications that require high topological fidelity.
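Homotopy warping, as described in Hu (2022), flips only pixels whose flip leaves the topology of the mask unchanged (so-called simple points); roughly speaking, mismatched pixels that cannot be flipped this way are the critical ones. The sketch below shows the standard 2D simple-point test under (8,4)-connectivity (8-connected foreground, 4-connected background) on a plain list-of-lists mask. It illustrates the underlying digital-topology check and is not the topolosses implementation; the helper names `_components` and `is_simple` are hypothetical, and treating out-of-bounds neighbours as background is an assumption.

```python
from typing import List, Set, Tuple

Offset = Tuple[int, int]
ADJ4: List[Offset] = [(-1, 0), (1, 0), (0, -1), (0, 1)]
ADJ8: List[Offset] = ADJ4 + [(-1, -1), (-1, 1), (1, -1), (1, 1)]


def _components(cells: List[Offset], adjacency: List[Offset]) -> List[Set[Offset]]:
    """Split a set of neighbour offsets into connected components."""
    remaining = set(cells)
    comps = []
    while remaining:
        start = remaining.pop()
        comp, stack = {start}, [start]
        while stack:
            y, x = stack.pop()
            for dy, dx in adjacency:
                nxt = (y + dy, x + dx)
                if nxt in remaining:
                    remaining.remove(nxt)
                    comp.add(nxt)
                    stack.append(nxt)
        comps.append(comp)
    return comps


def is_simple(mask: List[List[int]], y: int, x: int) -> bool:
    """True if flipping mask[y][x] preserves topology under (8,4)-connectivity.

    Out-of-bounds neighbours are treated as background (an assumption).
    """
    fg, bg = [], []
    for dy, dx in ADJ8:
        yy, xx = y + dy, x + dx
        inside = 0 <= yy < len(mask) and 0 <= xx < len(mask[0])
        (fg if inside and mask[yy][xx] else bg).append((dy, dx))
    # T8: number of 8-connected foreground components in the 3x3 ring.
    t8 = len(_components(fg, ADJ8))
    # T4: number of 4-connected background components in the ring that
    # touch the centre pixel through one of its 4-neighbours.
    t4 = sum(1 for comp in _components(bg, ADJ4)
             if any(abs(dy) + abs(dx) == 1 for dy, dx in comp))
    return t8 == 1 and t4 == 1


line = [[0, 0, 0],
        [1, 1, 1],
        [0, 0, 0]]
print(is_simple(line, 1, 1))  # False: removing the middle pixel splits the line
print(is_simple(line, 1, 0))  # True: removing an endpoint keeps one component
```

A pixel that fails this test (for example the middle of a thin line, or the only background pixel inside a ring) is one whose misclassification changes the number of components or holes, which is the kind of error the warping loss is designed to penalize.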

Parameters:
  • eight_connectivity (bool) – Determines whether to use 8-connectivity for foreground components (i.e., diagonally adjacent pixels belong to the same connected component) instead of 4-connectivity when building the component graph. Defaults to True (8-connectivity).

  • alpha (float) – Weighting factor for combining the base loss and the topology loss (i.e., base_loss + alpha * topology_loss). Defaults to 0.5.

  • sigmoid (bool) – If True, applies a sigmoid activation to the input before computing the loss. The sigmoid is not applied to the input passed to a custom base loss function. Defaults to False.

  • softmax (bool) – If True, applies a softmax activation to the input before computing the loss. The softmax is not applied to the input passed to a custom base loss function. Defaults to False.

  • use_base_loss (bool) – If False, the loss consists only of the topology component, and base_loss and alpha are ignored. Defaults to True.

  • base_loss (_Loss, optional) – The base loss function to be used alongside the topology loss. Defaults to None, meaning a standard cross-entropy loss will be used.

Raises:

ValueError – If more than one of [sigmoid, softmax] is set to True.
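The constructor options above can be illustrated with a short, framework-free sketch. `count_components` shows what `eight_connectivity` changes on a diagonal pattern, while `check_activation_flags` and `combine_losses` mirror the documented rules for `sigmoid`/`softmax`, `alpha`, and `use_base_loss`. All three are hypothetical helpers written for illustration: they operate on plain Python values rather than tensors and are not part of the topolosses API.

```python
from collections import deque
from typing import List


def count_components(mask: List[List[int]], eight_connectivity: bool = True) -> int:
    """Count foreground components of a binary 2D mask with a BFS flood fill."""
    steps = [(-1, 0), (1, 0), (0, -1), (0, 1)]
    if eight_connectivity:
        # Diagonal neighbours also count as adjacent.
        steps += [(-1, -1), (-1, 1), (1, -1), (1, 1)]
    h, w = len(mask), len(mask[0])
    seen = [[False] * w for _ in range(h)]
    count = 0
    for y in range(h):
        for x in range(w):
            if not mask[y][x] or seen[y][x]:
                continue
            count += 1
            seen[y][x] = True
            queue = deque([(y, x)])
            while queue:
                cy, cx = queue.popleft()
                for dy, dx in steps:
                    ny, nx = cy + dy, cx + dx
                    if 0 <= ny < h and 0 <= nx < w and mask[ny][nx] and not seen[ny][nx]:
                        seen[ny][nx] = True
                        queue.append((ny, nx))
    return count


def check_activation_flags(sigmoid: bool = False, softmax: bool = False) -> None:
    """Raise ValueError if both activations are requested."""
    if sigmoid and softmax:
        raise ValueError("At most one of [sigmoid, softmax] may be set to True.")


def combine_losses(base_loss: float, topology_loss: float,
                   alpha: float = 0.5, use_base_loss: bool = True) -> float:
    """Return base_loss + alpha * topology_loss, or only the topology term."""
    if not use_base_loss:
        # base_loss and alpha are ignored in this case.
        return topology_loss
    return base_loss + alpha * topology_loss


diagonal = [[1, 0, 0],
            [0, 1, 0],
            [0, 0, 1]]
print(count_components(diagonal, eight_connectivity=True))   # 1
print(count_components(diagonal, eight_connectivity=False))  # 3
print(combine_losses(1.0, 0.5))                       # 1.25
print(combine_losses(1.0, 0.5, use_base_loss=False))  # 0.5
```

The diagonal example is the case where the two connectivity choices disagree: a thin diagonal structure counts as one object under 8-connectivity but breaks into three under 4-connectivity.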

compute_warping_loss(input, target)

Compute cross-entropy loss only on pixels critical to preserving segmentation topology.
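A minimal sketch of the masked cross-entropy step, assuming a binary (single foreground channel) setting, flattened per-pixel probabilities, and a precomputed boolean mask of critical pixels. Whether topolosses averages over the critical pixels only (as done here) or over all pixels is an assumption, and `masked_binary_cross_entropy` is a hypothetical name.

```python
import math
from typing import Sequence


def masked_binary_cross_entropy(probs: Sequence[float], target: Sequence[int],
                                critical: Sequence[bool], eps: float = 1e-7) -> float:
    """Mean binary cross-entropy over critical pixels only (0.0 if there are none)."""
    total, n = 0.0, 0
    for p, t, c in zip(probs, target, critical):
        if not c:
            continue  # non-critical pixels contribute nothing
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        total += -(t * math.log(p) + (1 - t) * math.log(1.0 - p))
        n += 1
    return total / n if n else 0.0


probs = [0.9, 0.2, 0.6]
target = [1, 1, 0]
critical = [False, True, False]
# Only the second pixel counts: -log(0.2) ≈ 1.609
print(masked_binary_cross_entropy(probs, target, critical))
```

Restricting the loss to the critical mask means that well-classified or topologically harmless pixels add no gradient, so the topology term focuses entirely on the errors that change connectivity.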

forward(input: torch.Tensor, target: torch.Tensor) → torch.Tensor

Calculates the forward pass of the Warping Loss.

Parameters:
  • input (Tensor) – Input tensor of shape (batch_size, num_classes, H, W).

  • target (Tensor) – Target tensor of shape (batch_size, num_classes, H, W).

Returns:

The calculated warping loss.

Return type:

Tensor

Raises:
  • ValueError – If the shape of the ground truth is different from the input shape.

  • ValueError – If the number of classes is smaller than 2.
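The two documented ValueError conditions of forward amount to a shape check performed before any loss is computed. The sketch below works on shape tuples such as (batch_size, num_classes, H, W); `validate_forward_shapes` is a hypothetical helper and its error messages are illustrative, not the library's own. Since `torch.Size` is a tuple subclass, `input.shape` could be passed to it directly.

```python
from typing import Tuple


def validate_forward_shapes(input_shape: Tuple[int, ...],
                            target_shape: Tuple[int, ...]) -> None:
    """Raise ValueError on the conditions listed under forward's Raises section."""
    if tuple(input_shape) != tuple(target_shape):
        raise ValueError(
            f"Ground truth shape {tuple(target_shape)} differs from "
            f"input shape {tuple(input_shape)}."
        )
    # Dimension 1 holds num_classes for (batch_size, num_classes, H, W) tensors.
    if input_shape[1] < 2:
        raise ValueError(f"Expected at least 2 classes, got {input_shape[1]}.")


validate_forward_shapes((4, 2, 64, 64), (4, 2, 64, 64))  # OK: no exception
try:
    validate_forward_shapes((4, 1, 64, 64), (4, 1, 64, 64))
except ValueError as err:
    print(err)  # Expected at least 2 classes, got 1.
```

Because fewer than 2 classes is rejected, a binary segmentation task would presumably need to supply an explicit background channel alongside the foreground channel in both input and target.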