TopographLoss

class topolosses.losses.topograph.TopographLoss(*args: Any, **kwargs: Any)

Bases: _Loss

TopographLoss is a loss function designed to ensure strict topology preservation during image segmentation tasks.

The loss was introduced in:

Lux et al. (2024) Topograph: An efficient Graph-Based Framework for Strictly Topology Preserving Image Segmentation (https://arxiv.org/pdf/2411.03228)

By default, the Topograph component is combined with a Dice loss component. For more flexibility, a custom base loss function can be passed via base_loss.

Parameters:
  • num_processes (int) – Number of parallel processes to use for computation.

  • use_c (bool) – Whether to use the compiled C implementation instead of the pure Python version. Defaults to True.

  • sphere (bool) – If True, adds padding to create periodic boundary conditions (sphere topology). Defaults to False.

  • eight_connectivity (bool) – If True, uses 8-connectivity for foreground components when building the component graph, so diagonally adjacent pixels belong to the same connected component; if False, uses 4-connectivity. Defaults to True (8-connectivity).

  • aggregation (AggregationType) – Specifies the aggregation method for loss calculation across the batch. Possible values are mean, sum, max, min, ce, rms, and leg. Defaults to mean.

  • thres_distr (ThresholdDistribution) – Determines the distribution used for sampling the binarization threshold. Possible values are uniform and gaussian. Defaults to None, which uses a constant binarization threshold of 0.5.

  • thres_var (float) – If thres_distr is set, this variable controls the magnitude of the random threshold variation applied during loss computation; higher values add more noise. Defaults to 0.0.

  • include_background (bool) – If True, includes the background class in the Topograph computation. Whether the base loss component includes the background is controlled independently, through the configuration of the base loss.

  • alpha (float) – Weighting factor for the Topograph loss component. Only applied when a base loss component is used. Defaults to 0.1.

  • sigmoid (bool) – If True, applies a sigmoid activation to the input before computing the loss. Typically used for binary segmentation. Defaults to False.

  • softmax (bool) – If True, applies a softmax activation to the input before computing the loss. This is useful for multi-class segmentation tasks. Defaults to False. For other activation functions, set both sigmoid and softmax to False and apply the transformation before passing inputs to the loss.

  • use_base_component (bool) – If False, the loss consists only of the Topograph component, and a forward call returns the full (unweighted) Topograph component. base_loss, weights, and alpha are ignored when this flag is False.

  • base_loss (_Loss, optional) – The base loss function used alongside the Topograph loss. Defaults to None, in which case a Dice loss with default parameters is used.

Raises:

ValueError – If more than one of [sigmoid, softmax] is set to True.
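To make the eight_connectivity and sphere parameters concrete, the following minimal pure-Python sketch counts foreground components in a small binary mask. It is not the library's implementation: count_components is a hypothetical helper, and it emulates periodic boundaries with wrap-around indexing rather than the padding that the sphere option is described as adding.

```python
from collections import deque


def count_components(mask, eight_connectivity=True, periodic=False):
    """Count foreground connected components in a 2D binary mask (list of lists).

    Hypothetical helper for illustration only; not part of topolosses.
    """
    h, w = len(mask), len(mask[0])
    # 4-neighbourhood: up, down, left, right
    offsets = [(-1, 0), (1, 0), (0, -1), (0, 1)]
    if eight_connectivity:
        # 8-neighbourhood additionally links diagonal neighbours
        offsets += [(-1, -1), (-1, 1), (1, -1), (1, 1)]
    seen = [[False] * w for _ in range(h)]
    count = 0
    for r in range(h):
        for c in range(w):
            if not mask[r][c] or seen[r][c]:
                continue
            # unvisited foreground pixel: start a new component and flood-fill it
            count += 1
            seen[r][c] = True
            queue = deque([(r, c)])
            while queue:
                y, x = queue.popleft()
                for dy, dx in offsets:
                    ny, nx = y + dy, x + dx
                    if periodic:
                        # wrap around the borders (periodic boundary conditions)
                        ny, nx = ny % h, nx % w
                    elif not (0 <= ny < h and 0 <= nx < w):
                        continue
                    if mask[ny][nx] and not seen[ny][nx]:
                        seen[ny][nx] = True
                        queue.append((ny, nx))
    return count


diagonal = [[1, 0],
            [0, 1]]
print(count_components(diagonal, eight_connectivity=False))  # 2
print(count_components(diagonal, eight_connectivity=True))   # 1

edges = [[1, 0, 1],
         [1, 0, 1]]
print(count_components(edges, eight_connectivity=False))                 # 2
print(count_components(edges, eight_connectivity=False, periodic=True))  # 1
```

The diagonal pair is two components under 4-connectivity but one under 8-connectivity. The second mask shows how joining the left and right borders merges two separate columns into a single component, which is why periodic boundaries change the topology the loss sees.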
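The thres_distr and thres_var parameters can be pictured as drawing a fresh binarization threshold around 0.5 before the prediction is binarized. The sketch below is hypothetical: sample_threshold and binarize are illustrative names, and treating thres_var as the half-width of the uniform interval and as the standard deviation of the Gaussian is an assumption, not documented library behavior.

```python
import random


def sample_threshold(thres_distr=None, thres_var=0.0, rng=random):
    """Hypothetical sketch: draw a binarization threshold centred on 0.5."""
    if thres_distr is None:
        # documented default: constant threshold of 0.5
        return 0.5
    if thres_distr == "uniform":
        # assumption: thres_var is the half-width of the sampling interval
        t = rng.uniform(0.5 - thres_var, 0.5 + thres_var)
    elif thres_distr == "gaussian":
        # assumption: thres_var is the standard deviation
        t = rng.gauss(0.5, thres_var)
    else:
        raise ValueError(f"unknown threshold distribution: {thres_distr!r}")
    # clamp so the threshold stays a valid probability
    return min(max(t, 0.0), 1.0)


def binarize(probs, threshold):
    """Turn a 2D grid of probabilities into a 0/1 mask (strictly above threshold)."""
    return [[1 if p > threshold else 0 for p in row] for row in probs]


print(binarize([[0.2, 0.45, 0.9]], sample_threshold()))  # [[0, 0, 1]]

rng = random.Random(0)
noisy_t = sample_threshold("uniform", thres_var=0.1, rng=rng)
print(binarize([[0.2, 0.45, 0.9]], noisy_t))
```

In this sketch, thres_var=0.0 makes both distributions collapse to the constant 0.5 threshold, consistent with the documented default of 0.0.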
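The sigmoid and softmax flags, together with the ValueError raised when both are set, can be illustrated on a single pixel's class scores. activate is a hypothetical plain-Python helper, not part of topolosses; a tensor implementation would apply softmax across the class (channel) dimension.

```python
import math


def _sigmoid(z):
    # numerically stable logistic function (avoids overflow for large |z|)
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def activate(logits, sigmoid=False, softmax=False):
    """Hypothetical sketch: apply the optional activation to one pixel's logits."""
    if sigmoid and softmax:
        # mirrors the documented constructor check
        raise ValueError("At most one of [sigmoid, softmax] can be set to True.")
    if sigmoid:
        # independent per-class probabilities (binary / multi-label case)
        return [_sigmoid(z) for z in logits]
    if softmax:
        # subtract the max logit for numerical stability; probabilities sum to 1
        m = max(logits)
        exps = [math.exp(z - m) for z in logits]
        total = sum(exps)
        return [e / total for e in exps]
    # neither flag: inputs are assumed to be activated already
    return list(logits)


print(activate([0.0], sigmoid=True))  # [0.5]
print(activate([1.0, 2.0, 3.0], softmax=True))
```

Leaving both flags False passes the scores through unchanged, which matches the documented advice to apply any other activation before calling the loss.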

compute_topopgraph_loss(input, target, starting_class, num_classes)

forward(input: torch.Tensor, target: torch.Tensor) → torch.Tensor

Calculates the forward pass of the topograph loss.

Parameters:
  • input (Tensor) – Input tensor of shape (batch_size, num_classes, H, W).

  • target (Tensor) – Target tensor of shape (batch_size, num_classes, H, W).

Returns:

The calculated topological loss.

Return type:

Tensor

Raises:
  • ValueError – If the shape of the ground truth is different from the input shape.

  • ValueError – If softmax=True and the number of channels for the prediction is 1.
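The two input checks documented for forward can be expressed as a small validation step on tensor shapes. check_forward_inputs is a hypothetical helper that operates on shape tuples of the form (batch_size, num_classes, H, W); it sketches the documented contract and is not the library's code.

```python
def check_forward_inputs(input_shape, target_shape, softmax=False):
    """Hypothetical sketch of the documented forward() input checks.

    Shapes are tuples of the form (batch_size, num_classes, H, W).
    """
    input_shape, target_shape = tuple(input_shape), tuple(target_shape)
    if input_shape != target_shape:
        # the ground truth must have exactly the same shape as the prediction
        raise ValueError(
            f"ground truth has shape {target_shape}, expected {input_shape}"
        )
    if softmax and input_shape[1] == 1:
        # softmax over a single channel would always yield 1.0
        raise ValueError("softmax=True requires more than one prediction channel")


check_forward_inputs((2, 3, 64, 64), (2, 3, 64, 64), softmax=True)  # passes
check_forward_inputs((2, 1, 64, 64), (2, 1, 64, 64))                # passes (e.g. with sigmoid)
```

A single-channel prediction is still valid when sigmoid is used instead of softmax, which is why only the softmax case is rejected.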