MosinLoss
- class topolosses.losses.mosin.MosinLoss(*args: Any, **kwargs: Any)
Bases: _Loss
A topology-aware loss function for curvilinear structure delineation using perceptual features.
- The loss has been defined in:
Mosinska et al. (2018) Beyond the Pixel-Wise Loss for Topology-Aware Delineation.
This loss uses a pre-trained VGG19 network to extract multi-level features from predictions and targets and compares them to enforce topological consistency. By default, it is combined with a pixel-wise base loss.
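As a rough illustration of the multi-level feature comparison described above, the sketch below sums the mean squared difference between prediction and target features over several levels. The plain lists stand in for VGG19 activations, and `feature_distance` and the toy values are hypothetical. This is not the topolosses implementation, whose exact normalisation may differ.

```python
from typing import Sequence


def feature_distance(
    pred_feats: Sequence[Sequence[float]],
    target_feats: Sequence[Sequence[float]],
) -> float:
    """Sum, over feature levels, of the mean squared difference (hypothetical helper)."""
    if len(pred_feats) != len(target_feats):
        raise ValueError("Both inputs need the same number of feature levels.")
    total = 0.0
    for pred, target in zip(pred_feats, target_feats):
        if len(pred) != len(target):
            raise ValueError("Feature maps at one level differ in size.")
        # Normalise by the number of elements so that levels of different
        # size contribute on a comparable scale.
        total += sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)
    return total


# Toy "features" at two levels (flattened); real ones would be VGG19 activations.
pred_levels = [[0.0, 1.0, 1.0, 0.0], [0.5, 0.5]]
target_levels = [[0.0, 1.0, 0.0, 0.0], [0.5, 1.5]]
print(feature_distance(pred_levels, target_levels))  # 0.25 + 0.5 = 0.75
```

Identical features give a distance of zero, so the term only penalises predictions whose features differ from those of the target.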
- Parameters:
include_background (bool) – If True, includes the background class in feature extraction. Defaults to False.
alpha (float) – Weighting factor for combining the base loss and the topology loss (i.e., base_loss + alpha * topology_loss). Defaults to 0.5.
sigmoid (bool) – If True, applies a sigmoid activation to the input before computing the loss. Sigmoid is not applied before passing it to a custom base loss function. Defaults to False.
softmax (bool) – If True, applies a softmax activation to the input before computing the loss. Softmax is not applied before passing it to a custom base loss function. Defaults to False.
use_base_loss (bool) – If False, the loss consists only of the topology component. The base_loss and alpha arguments are ignored if this flag is set to False. Defaults to True.
base_loss (_Loss, optional) – The base loss function to be used alongside the topology loss. Defaults to None, meaning a standard cross-entropy loss will be used.
- Raises:
ValueError – If more than one of [sigmoid, softmax] is set to True.
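The interaction between alpha, use_base_loss, and the activation flags can be sketched in plain Python. This is a minimal sketch that assumes the loss values are already computed as floats. `check_activations` and `combine_losses` are hypothetical helpers written for illustration; they are not functions exposed by topolosses.

```python
def check_activations(sigmoid: bool = False, softmax: bool = False) -> None:
    """Reject configurations that enable more than one activation."""
    if sum([sigmoid, softmax]) > 1:
        raise ValueError("At most one of [sigmoid, softmax] may be True.")


def combine_losses(
    base_loss: float,
    topology_loss: float,
    alpha: float = 0.5,
    use_base_loss: bool = True,
) -> float:
    """Return base_loss + alpha * topology_loss, or the topology term alone."""
    if not use_base_loss:
        # base_loss and alpha are ignored in this mode, as documented.
        return topology_loss
    return base_loss + alpha * topology_loss


check_activations(sigmoid=True)  # a single activation is accepted
print(combine_losses(1.0, 4.0))  # 1.0 + 0.5 * 4.0 = 3.0
print(combine_losses(1.0, 4.0, use_base_loss=False))  # 4.0
```

Calling `check_activations(sigmoid=True, softmax=True)` raises a ValueError, mirroring the documented constructor error.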
- forward(input: torch.Tensor, target: torch.Tensor) → torch.Tensor
Calculates the forward pass of the Mosin Loss.
- Parameters:
input (Tensor) – Input tensor of shape (batch_size, num_classes, H, W).
target (Tensor) – Target tensor of shape (batch_size, num_classes, H, W).
- Returns:
The calculated Mosin loss.
- Return type:
Tensor
- Raises:
ValueError – If the shape of the ground truth is different from the input shape.
ValueError – If softmax=True and the number of channels for the prediction is 1.
ValueError – If the spatial dimensions (H, W) of the input are smaller than 32x32.
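The three documented error conditions can be checked on shapes alone before calling forward. The following is a sketch under assumptions: `check_forward_shapes` is a hypothetical helper, shapes are passed as plain tuples, and the order in which the library performs its own checks is not stated in this reference.

```python
from typing import Tuple

MIN_SIZE = 32  # smallest height/width accepted, per the documented error

Shape = Tuple[int, int, int, int]  # (batch_size, num_classes, H, W)


def check_forward_shapes(
    input_shape: Shape, target_shape: Shape, softmax: bool = False
) -> None:
    """Raise ValueError for the shape problems listed under forward()."""
    if tuple(input_shape) != tuple(target_shape):
        raise ValueError("Target shape must match input shape.")
    _, channels, height, width = input_shape
    if softmax and channels == 1:
        raise ValueError("softmax=True needs more than one channel.")
    if min(height, width) < MIN_SIZE:
        raise ValueError("Spatial dimensions must be at least 32x32.")


# A two-class batch of 64x64 images passes all three checks.
check_forward_shapes((4, 2, 64, 64), (4, 2, 64, 64), softmax=True)
```

Running such a check in a data-loading test can surface undersized crops or single-channel softmax setups before training starts.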