TopographLoss
- class topolosses.losses.topograph.TopographLoss(*args: Any, **kwargs: Any)
Bases: _Loss
TopographLoss is a loss function designed to ensure strict topology preservation during image segmentation tasks.
- The loss was introduced in:
Lux et al. (2024), Topograph: An efficient Graph-Based Framework for Strictly Topology Preserving Image Segmentation (https://arxiv.org/pdf/2411.03228)
By default, the Topograph component is combined with a Dice loss component. For more flexibility, a custom base loss function can be passed via base_loss.
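For reference, the default base component is a Dice loss. The following is a minimal, framework-free sketch of a soft Dice loss on flattened probabilities for a single class, together with one plausible way of weighting the Topograph term by alpha. It is not the library's implementation: the function names soft_dice_loss and combined_loss are hypothetical, and the additive form base + alpha * topograph is an assumption that the parameter descriptions do not confirm.

```python
from typing import Sequence


def soft_dice_loss(pred: Sequence[float], target: Sequence[float], eps: float = 1e-6) -> float:
    """Soft Dice loss for one class on flattened probabilities: 1 - 2|P∩T| / (|P| + |T|)."""
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same length")
    intersection = sum(p * t for p, t in zip(pred, target))
    denominator = sum(pred) + sum(target)
    # eps keeps the ratio defined when both prediction and target are empty.
    return 1.0 - (2.0 * intersection + eps) / (denominator + eps)


def combined_loss(base: float, topo: float, alpha: float = 0.1) -> float:
    # ASSUMPTION: additive weighting; the real TopographLoss may combine the terms differently.
    return base + alpha * topo
```

A perfect overlap gives a Dice loss of 0, and a completely missed foreground gives a value close to 1.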
- Parameters:
num_processes (int) – Number of parallel processes to use for computation.
use_c (bool) – Whether to use the compiled C implementation instead of the pure Python version. Defaults to True.
sphere (bool) – If True, adds padding to create periodic boundary conditions (sphere topology). Defaults to False.
eight_connectivity (bool) – If True, uses 8-connectivity for foreground components (i.e., diagonally adjacent pixels belong to the same connected component); if False, uses 4-connectivity when building the component graph. Defaults to True (8-connectivity).
aggregation (AggregationType) – Aggregation method for the loss across the batch. Possible values are mean, sum, max, min, ce, rms, and leg. Defaults to mean.
thres_distr (ThresholdDistribution) – Distribution used to sample the binarization threshold. Possible values are uniform and gaussian. Defaults to None, which uses a constant binarization threshold of 0.5.
thres_var (float) – If thres_distr is set, controls the magnitude of the random threshold variation applied during loss computation; higher values add more noise. Defaults to 0.0.
include_background (bool) – If True, includes the background class in the topograph computation. Background inclusion in the base loss component should be controlled independently.
alpha (float) – Weighting factor for the Topograph loss component. Only applied if a base loss is used. Defaults to 0.1.
sigmoid (bool) – If True, applies a sigmoid activation to the input before computing the loss. Typically used for binary segmentation. Defaults to False.
softmax (bool) – If True, applies a softmax activation to the input before computing the loss. This is useful for multi-class segmentation tasks. Defaults to False. For other activation functions, set both sigmoid and softmax to False and apply the transformation before passing inputs to the loss.
use_base_component (bool) – If False, the loss consists only of the Topograph component, and a forward call returns the full Topograph component; base_loss, weights, and alpha are ignored in this case.
base_loss (_Loss, optional) – The base loss function to be used alongside the Topograph loss. Defaults to None, meaning a Dice component with default parameters is used.
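To illustrate how thres_distr, thres_var, and eight_connectivity interact, here is a minimal pure-Python sketch that binarizes a probability map with a (possibly randomized) threshold and then counts foreground components under 4- or 8-connectivity. The helper names sample_threshold and count_components are hypothetical, and the noise model (centred on 0.5, uniform over [0.5 - thres_var, 0.5 + thres_var] or Gaussian with standard deviation thres_var) is an assumption, not taken from the library.

```python
import random
from collections import deque
from typing import List, Optional

Grid = List[List[float]]


def sample_threshold(distr: Optional[str], var: float, rng: random.Random) -> float:
    """Return a binarization threshold: 0.5 if distr is None, else 0.5 plus random noise.

    ASSUMPTION: the noise scheme below is illustrative; the library's may differ.
    """
    if distr is None:
        return 0.5
    if distr == "uniform":
        return rng.uniform(0.5 - var, 0.5 + var)
    if distr == "gaussian":
        return rng.gauss(0.5, var)
    raise ValueError(f"unknown distribution: {distr}")


def count_components(prob: Grid, threshold: float, eight_connectivity: bool = True) -> int:
    """Count foreground components (pixels strictly above threshold) via breadth-first search."""
    h, w = len(prob), len(prob[0])
    fg = [[prob[i][j] > threshold for j in range(w)] for i in range(h)]
    # 4-connectivity: edge neighbours only; 8-connectivity adds the diagonals.
    offsets = [(-1, 0), (1, 0), (0, -1), (0, 1)]
    if eight_connectivity:
        offsets += [(-1, -1), (-1, 1), (1, -1), (1, 1)]
    seen = [[False] * w for _ in range(h)]
    count = 0
    for i in range(h):
        for j in range(w):
            if fg[i][j] and not seen[i][j]:
                count += 1  # new component found; flood-fill it
                seen[i][j] = True
                queue = deque([(i, j)])
                while queue:
                    y, x = queue.popleft()
                    for dy, dx in offsets:
                        ny, nx = y + dy, x + dx
                        if 0 <= ny < h and 0 <= nx < w and fg[ny][nx] and not seen[ny][nx]:
                            seen[ny][nx] = True
                            queue.append((ny, nx))
    return count
```

For example, a 2×2 map with foreground only on the main diagonal forms one component under 8-connectivity and two under 4-connectivity, which is exactly the distinction eight_connectivity controls.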
- Raises:
ValueError – If more than one of [sigmoid, softmax] is set to True.
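The aggregation parameter reduces per-sample loss values to a single scalar. A minimal sketch of the well-defined options (mean, sum, max, min, rms) follows; ce and leg are not described in this reference, so they are deliberately omitted rather than guessed. The aggregate function is hypothetical and operates on plain Python floats, not on the library's tensors.

```python
import math
from typing import Sequence


def aggregate(per_sample: Sequence[float], method: str = "mean") -> float:
    """Reduce per-sample loss values to a scalar (mean, sum, max, min, or rms only)."""
    values = list(per_sample)
    if not values:
        raise ValueError("need at least one value")
    if method == "mean":
        return sum(values) / len(values)
    if method == "sum":
        return sum(values)
    if method == "max":
        return max(values)
    if method == "min":
        return min(values)
    if method == "rms":
        # Root mean square: emphasizes samples with large loss more than the mean does.
        return math.sqrt(sum(v * v for v in values) / len(values))
    raise ValueError(f"aggregation not covered by this sketch: {method}")
```

Choosing max or rms over mean puts more weight on the worst samples in a batch, while sum scales with batch size.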
- forward(input: torch.Tensor, target: torch.Tensor) → torch.Tensor
Calculates the forward pass of the topograph loss.
- Parameters:
input (Tensor) – Input tensor of shape (batch_size, num_classes, H, W).
target (Tensor) – Target tensor of shape (batch_size, num_classes, H, W).
- Returns:
The calculated topological loss.
- Return type:
Tensor
- Raises:
ValueError – If the shape of the ground truth is different from the input shape.
ValueError – If softmax=True and the number of channels for the prediction is 1.
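The documented error conditions reduce to two small checks: one on the activation flags (raised at construction) and one on the tensor shapes (raised in forward). The sketch below reproduces those conditions on plain shape tuples of the form (batch_size, num_classes, H, W); the function names and error messages are hypothetical, and the real class may perform additional validation.

```python
from typing import Sequence


def validate_flags(sigmoid: bool, softmax: bool) -> None:
    # Constructor-level check: at most one built-in activation may be enabled.
    if sigmoid and softmax:
        raise ValueError("At most one of [sigmoid, softmax] can be set to True.")


def validate_inputs(input_shape: Sequence[int], target_shape: Sequence[int], softmax: bool = False) -> None:
    # forward()-level checks on (batch_size, num_classes, H, W) shapes.
    if tuple(input_shape) != tuple(target_shape):
        raise ValueError(
            f"Shape mismatch: input {tuple(input_shape)} vs target {tuple(target_shape)}"
        )
    # Softmax over a single channel is degenerate (always 1.0), so it is rejected.
    if softmax and input_shape[1] == 1:
        raise ValueError("softmax=True requires more than one channel.")
```

In practice, a binary model with a single output channel should be paired with sigmoid=True (or no activation flag), not softmax=True.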