Define a loss for geospatial semantic segmentation with deep supervision, based on a modified unified focal loss framework and implemented as a subclass of torch::nn_module().
nCls = 3,
dsWghts = c(0.6, 0.2, 0.1, 0.1),
lambda = 0.5,
gamma = 0.5,
delta = 0.6,
smooth = 1e-08,
zeroStart = TRUE,
clsWghtsDist = 1,
clsWghtsReg = 1,
useLogCosH = FALSE,
device = "cuda"
- nCls
Number of classes being differentiated.
- dsWghts
Vector of 4 weights applied to the losses calculated at each spatial resolution when using deep supervision. The default is c(0.6, 0.2, 0.1, 0.1), which places larger weights on the results at higher spatial resolutions.
- lambda
Term used to control the relative weighting of the distribution- and region-based losses. Default is 0.5, or equal weighting between the losses. If lambda = 1, only the distribution-based loss is considered. If lambda = 0, only the region-based loss is considered. Values between 0.5 and 1 put more weight on the distribution-based loss while values between 0 and 0.5 put more weight on the region-based loss.
- gamma
Parameter that controls the weighting applied to difficult-to-predict pixels (for distribution-based losses) or difficult-to-predict classes (for region-based losses). Smaller values increase the weight applied to difficult samples or classes. Default is 0.5. A value of 1 applies no focal weighting. Value must be greater than 0 and less than or equal to 1.
- delta
Parameter that controls the relative weightings of false positive and false negative errors for each class. Different weightings can be provided for each class. The default is 0.6, which results in prioritizing false negative errors relative to false positive errors.
- smooth
Smoothing factor to avoid divide-by-zero errors and provide numeric stability. Default is 1e-8. Using the default is recommended.
- zeroStart
TRUE or FALSE. If class indices start at 0 as opposed to 1, this should be set to TRUE. This is required to implement one-hot encoding since R starts indexing at 1. Default is TRUE.
- clsWghtsDist
Vector of class weights for use in calculating a weighted version of the cross entropy (CE) loss. Default is 1, which weights all classes equally.
- clsWghtsReg
Vector of class weights for use in calculating a weighted version of the region-based loss. Default is 1, which weights all classes equally.
- useLogCosH
TRUE or FALSE. Whether or not to apply a logCosH transformation to the region-based loss. Default is FALSE.
- device
Device used for computation, defined using torch_device(). Default is "cuda".
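The dsWghts and zeroStart arguments described above can be illustrated with a small base R sketch. It does not call the package or torch: the per-resolution loss values are made up, lossByRes and labels0 are hypothetical names, and the weighted sum is an assumption about how the four losses are aggregated.

```r
# Hypothetical losses at the four spatial resolutions, ordered from the
# highest resolution to the lowest (values made up for illustration).
lossByRes <- c(0.40, 0.55, 0.70, 0.90)

# Default deep supervision weights from the dsWghts argument.
dsWghts <- c(0.6, 0.2, 0.1, 0.1)

# Assumed aggregation: a weighted sum of the per-resolution losses.
totalLoss <- sum(dsWghts * lossByRes)

# zeroStart = TRUE case: class indices 0, 1, 2 are shifted by 1 so they can
# be used as 1-based R indices when building a one-hot encoding.
nCls <- 3
labels0 <- c(0, 2, 1)
oneHot <- diag(nCls)[labels0 + 1, ]
```

Each row of oneHot holds a single 1 in the column of the shifted class index, which is why zero-based labels need the offset in R.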
Implementation of a modified version of the unified focal loss after:
Yeung, M., Sala, E., Schönlieb, C.B. and Rundo, L., 2022. Unified focal loss: Generalising Dice and cross entropy-based losses to handle class imbalanced medical image segmentation. Computerized Medical Imaging and Graphics, 95, p.102026.
Modifications include (1) allowing users to define class weights for both the distribution-based and region-based losses, (2) using class weights as opposed to the symmetric and asymmetric methods implemented by the authors, and (3) including an option to apply a logCosH transform to the region-based loss.
This loss has three key hyperparameters that control its implementation. Lambda controls the relative weight of the distribution- and region-based losses. The default is 0.5, which weights the two losses equally. If lambda = 1, only the distribution-based loss is considered. If lambda = 0, only the region-based loss is considered. Values between 0.5 and 1 put more weight on the distribution-based loss while values between 0 and 0.5 put more weight on the region-based loss.
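The role of lambda can be sketched as a convex combination of the two component losses. This is a minimal base R illustration that matches the behavior described above at lambda = 0, 0.5, and 1. The function combineLosses and the loss values are hypothetical and are not part of the package.

```r
# Hypothetical component losses for one batch (made-up values).
distLoss <- 0.8  # distribution-based (cross entropy type) loss
regLoss <- 0.3   # region-based (Dice/Tversky type) loss

# Assumed combination rule: lambda weights the distribution-based loss and
# (1 - lambda) weights the region-based loss.
combineLosses <- function(distLoss, regLoss, lambda = 0.5) {
  stopifnot(lambda >= 0, lambda <= 1)
  lambda * distLoss + (1 - lambda) * regLoss
}

onlyDist <- combineLosses(distLoss, regLoss, lambda = 1)  # distribution-based only
onlyReg <- combineLosses(distLoss, regLoss, lambda = 0)   # region-based only
equalMix <- combineLosses(distLoss, regLoss, lambda = 0.5) # equal weighting
```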
Gamma controls the focal adjustment, which places increased weight on difficult-to-predict pixels (for distribution-based losses) or difficult-to-predict classes (for region-based losses). Lower gamma values put increased weight on difficult samples or classes. Using a value of 1 equates to not using a focal adjustment.
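The effect of gamma can be sketched in base R. The two formulas below follow the focal terms in the cited Yeung et al. (2022) paper and are an assumption about the form used here, not a copy of the package code. The names focalCE and focalRegion are hypothetical.

```r
# Assumed focal cross entropy term for one pixel, where pTrue is the
# predicted probability of the correct class.
focalCE <- function(pTrue, gamma = 0.5) {
  modulator <- (1 - pTrue)^(1 - gamma)
  modulator * (-log(pTrue))
}

# Assumed focal region-based term for one class, where ti is that class's
# Tversky index (between 0 and 1).
focalRegion <- function(ti, gamma = 0.5) {
  (1 - ti)^gamma
}

easy <- 0.9  # well-predicted pixel
hard <- 0.3  # poorly predicted pixel

# Loss of the hard pixel relative to the easy pixel, without and with the
# focal adjustment.
ratioNoFocal <- focalCE(hard, gamma = 1) / focalCE(easy, gamma = 1)
ratioFocal <- focalCE(hard, gamma = 0.5) / focalCE(easy, gamma = 0.5)
```

Under these assumed formulas, gamma = 1 reduces focalCE to the plain cross entropy term and focalRegion to 1 minus the Tversky index, and lowering gamma raises the loss of the hard pixel relative to the easy one.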
The delta term controls the relative weight of false positive and false negative errors for each class. The default is 0.6 for each class, which places a higher weight on false negative errors than on false positive errors for that class.
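The weighting described above can be sketched with a Tversky index in which delta multiplies the false negative count and (1 - delta) multiplies the false positive count. This form is assumed from the cited paper and from the parameter description. The function tverskyIndex and the counts are hypothetical.

```r
# Assumed Tversky index for one class from true positive (tp), false
# negative (fn), and false positive (fp) counts.
tverskyIndex <- function(tp, fn, fp, delta = 0.6, smooth = 1e-8) {
  (tp + smooth) / (tp + delta * fn + (1 - delta) * fp + smooth)
}

# Same number of errors, made either as false negatives or false positives.
tiMissed <- tverskyIndex(tp = 80, fn = 20, fp = 0)  # 20 false negatives
tiExtra <- tverskyIndex(tp = 80, fn = 0, fp = 20)   # 20 false positives

lossMissed <- 1 - tiMissed
lossExtra <- 1 - tiExtra

# With delta = 0.5 both error types are weighted equally and the index
# matches the Dice coefficient, 2 * tp / (2 * tp + fn + fp).
tiDice <- tverskyIndex(tp = 80, fn = 10, fp = 10, delta = 0.5)
```

With delta = 0.6, the 20 false negatives produce a larger loss than the 20 false positives.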
By adjusting the lambda, gamma, delta, and class weight terms, the user can implement a variety of different loss metrics including cross entropy loss, weighted cross entropy loss, focal cross entropy loss, focal weighted cross entropy loss, Dice loss, focal Dice loss, Tversky loss, and focal Tversky loss.
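One reading of the parameter descriptions above is summarized in the table built below. The settings are inferred from this documentation rather than verified against the package. The gamma value of 0.5 stands in for any value below 1, and the delta value of 0.6 stands in for any value other than 0.5. Delta is listed as 0.5 for the cross entropy variants on the assumption that it only affects the region-based term.

```r
# Hypothetical lookup table of argument settings for each named loss.
lossSettings <- data.frame(
  loss = c("cross entropy", "weighted cross entropy",
           "focal cross entropy", "focal weighted cross entropy",
           "Dice", "focal Dice", "Tversky", "focal Tversky"),
  lambda = c(1, 1, 1, 1, 0, 0, 0, 0),       # 1 = distribution only, 0 = region only
  gamma = c(1, 1, 0.5, 0.5, 1, 0.5, 1, 0.5), # 1 = no focal adjustment
  delta = c(0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.6, 0.6),
  clsWghts = c("equal", "user-defined", "equal", "user-defined",
               "equal", "equal", "equal", "equal"),
  stringsAsFactors = FALSE
)

print(lossSettings)
```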