
Define a loss for semantic segmentation based on a modified unified focal loss framework, implemented as a subclass of torch::nn_module().

Usage

defineUnifiedFocalLoss(
  nCls = 3,
  lambda = 0.5,
  gamma = 0.5,
  delta = 0.6,
  smooth = 1e-08,
  zeroStart = TRUE,
  clsWghtsDist = 1,
  clsWghtsReg = 1,
  useLogCosH = FALSE,
  device = "cuda"
)

Arguments

nCls

Number of classes being differentiated.

lambda

Term used to control the relative weighting of the distribution- and region-based losses. Default is 0.5, or equal weighting between the losses. If lambda = 1, only the distribution-based loss is considered. If lambda = 0, only the region-based loss is considered. Values between 0.5 and 1 put more weight on the distribution-based loss, while values between 0 and 0.5 put more weight on the region-based loss.

gamma

Parameter that controls the weighting applied to difficult-to-predict pixels (for distribution-based losses) or difficult-to-predict classes (for region-based losses). Smaller values increase the weight applied to difficult samples or classes. Default is 0.5. A value of 1 applies no focal weighting. Value must be larger than 0 and less than or equal to 1.

delta

Parameter that controls the relative weightings of false positive and false negative errors for each class. Different weightings can be provided for each class. The default is 0.6, which results in prioritizing false negative errors relative to false positive errors.

smooth

Smoothing factor to avoid divide-by-zero errors and provide numeric stability. Default is 1e-8. Using the default is recommended.

zeroStart

TRUE or FALSE. If class indices start at 0 as opposed to 1, this should be set to TRUE. This is required to implement one-hot encoding since R starts indexing at 1. Default is TRUE.

clsWghtsDist

Vector of class weights for use in calculating a weighted version of the distribution-based, cross entropy (CE) loss. Default is for all classes to be equally weighted.

clsWghtsReg

Vector of class weights for use in calculating a weighted version of the region-based loss. Default is for all classes to be equally weighted.

useLogCosH

TRUE or FALSE. Whether or not to apply a logCosH transformation to the region-based loss. Default is FALSE.

device

Device used for computation, such as "cpu" or "cuda". Define using torch_device(). Default is "cuda".

Value

Loss metric for use in training process.

Details

Implementation of modified version of the unified focal loss after:

Yeung, M., Sala, E., Schönlieb, C.B. and Rundo, L., 2022. Unified focal loss: Generalising Dice and cross entropy-based losses to handle class imbalanced medical image segmentation. Computerized Medical Imaging and Graphics, 95, p.102026.

Modifications include (1) allowing users to define class weights for both the distribution-based and region-based losses, (2) using class weights as opposed to the symmetric and asymmetric methods implemented by the authors, and (3) including an option to apply a logCosH transform to the region-based loss.
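The logCosH option in modification (3) can be illustrated with a small base R sketch. It assumes the transform is simply log(cosh(x)) applied to the region-based loss value; the helper name logCosH is hypothetical and is not taken from the package source.

```r
# Hypothetical helper: log-cosh transform of a loss value (assumed form)
logCosH <- function(x) {
  log(cosh(x))
}

# Region-based loss values between 0 and 1 (illustrative numbers only)
regLoss <- c(0, 0.1, 0.5, 1)

# The transform is 0 at 0 and never exceeds the untransformed loss
data.frame(regLoss = regLoss, transformed = logCosH(regLoss))
```

For small loss values the transform behaves approximately like x^2 / 2, so it smooths the loss near its minimum.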

This loss has three key hyperparameters that control its implementation. Lambda controls the relative weight of the distribution- and region-based losses. The default is 0.5, which weights the two losses equally. If lambda = 1, only the distribution-based loss is considered. If lambda = 0, only the region-based loss is considered. Values between 0.5 and 1 put more weight on the distribution-based loss, while values between 0 and 0.5 put more weight on the region-based loss.
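The role of lambda can be sketched as a weighted sum of the two loss components. This is a minimal base R illustration, assuming a simple linear combination consistent with the description above; combineLosses and the example loss values are hypothetical.

```r
# Hypothetical helper: blend distribution- and region-based losses (assumed form)
combineLosses <- function(distLoss, regLoss, lambda = 0.5) {
  stopifnot(lambda >= 0, lambda <= 1)
  lambda * distLoss + (1 - lambda) * regLoss
}

distLoss <- 0.8  # illustrative distribution-based (cross entropy) loss
regLoss <- 0.4   # illustrative region-based (Dice/Tversky) loss

combineLosses(distLoss, regLoss, lambda = 1)    # distribution-based loss only
combineLosses(distLoss, regLoss, lambda = 0)    # region-based loss only
combineLosses(distLoss, regLoss, lambda = 0.5)  # equal weighting
```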

Gamma controls the focal adjustment, which places increased weight on difficult-to-predict pixels (for distribution-based losses) or difficult-to-predict classes (for region-based losses). Lower gamma values put increased weight on difficult samples or classes. Using a value of 1 equates to not using a focal adjustment. Gamma must be larger than 0 and less than or equal to 1.
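The effect of gamma can be sketched with the focal terms described by Yeung et al. (2022): a pixel-level factor (1 - p)^(1 - gamma) for the distribution-based loss and a class-level term (1 - TI)^gamma for the region-based loss, where p is the predicted probability of the correct class and TI is the Tversky index. That this package uses exactly these forms is an assumption, and both helper names are hypothetical.

```r
# Hypothetical helper: focal factor applied to each pixel's cross entropy (assumed form)
focalPixelWeight <- function(pTrue, gamma) {
  (1 - pTrue)^(1 - gamma)
}

# Hypothetical helper: focal region-based loss for one class (assumed form)
focalRegionLoss <- function(tversky, gamma) {
  (1 - tversky)^gamma
}

pTrue <- c(hard = 0.1, easy = 0.9)  # probability assigned to the correct class

focalPixelWeight(pTrue, gamma = 1)    # all weights equal 1: no focal adjustment
focalPixelWeight(pTrue, gamma = 0.5)  # hard pixel receives the larger weight

focalRegionLoss(0.7, gamma = 1)       # plain 1 - TI
focalRegionLoss(0.7, gamma = 0.5)     # focal version of the same class loss
```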

The delta term controls the relative weight of false positive and false negative errors for each class. The default is 0.6 for each class, which results in placing a higher weight on false negative as opposed to false positive errors relative to that class.
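The delta weighting can be illustrated with a Tversky-style index computed from true positive, false positive, and false negative counts. The sketch below assumes the form used by Yeung et al. (2022), in which delta multiplies the false negatives and (1 - delta) multiplies the false positives; tverskyIndex is a hypothetical helper, not a function exported by the package.

```r
# Hypothetical helper: Tversky index for a single class (assumed form)
tverskyIndex <- function(tp, fp, fn, delta = 0.6, smooth = 1e-8) {
  (tp + smooth) / (tp + delta * fn + (1 - delta) * fp + smooth)
}

# With delta = 0.5 the index reduces to the Dice coefficient
tverskyIndex(tp = 80, fp = 10, fn = 10, delta = 0.5)
(2 * 80) / (2 * 80 + 10 + 10)

# With delta = 0.6, 20 false negatives lower the index more than 20 false positives
tverskyIndex(tp = 80, fp = 0, fn = 20, delta = 0.6)
tverskyIndex(tp = 80, fp = 20, fn = 0, delta = 0.6)
```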

By adjusting the lambda, gamma, delta, and class weight terms, the user can implement a variety of different loss metrics including cross entropy loss, weighted cross entropy loss, focal cross entropy loss, focal weighted cross entropy loss, Dice loss, focal Dice loss, Tversky loss, and focal Tversky loss.
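As a rough guide, the settings below sketch how the losses listed above might be configured. The gamma value of 0.8 and the delta value of 0.6 are illustrative choices, and the delta value shown for the cross entropy variants is a placeholder, since how delta interacts with the distribution-based loss in this implementation is not confirmed here.

```r
# Illustrative hyperparameter settings (assumptions, not package defaults)
lossSettings <- list(
  crossEntropy      = list(lambda = 1, gamma = 1,   delta = 0.5),
  focalCrossEntropy = list(lambda = 1, gamma = 0.8, delta = 0.5),
  dice              = list(lambda = 0, gamma = 1,   delta = 0.5),
  focalDice         = list(lambda = 0, gamma = 0.8, delta = 0.5),
  tversky           = list(lambda = 0, gamma = 1,   delta = 0.6),
  focalTversky      = list(lambda = 0, gamma = 0.8, delta = 0.6)
)

# Weighted variants add unequal class weights, e.g. for a three-class problem
weightedCrossEntropy <- c(lossSettings$crossEntropy,
                          list(clsWghtsDist = c(1, 2, 3)))

# Summarize the settings as a table
settingsTable <- do.call(rbind, lapply(lossSettings, as.data.frame))
print(settingsTable)

# Possible usage, assuming defineUnifiedFocalLoss() is available:
# myLoss <- do.call(defineUnifiedFocalLoss,
#                   c(list(nCls = 3, device = "cpu"), lossSettings$dice))
```

The dice entry matches the settings used in the Examples section (lambda = 0, gamma = 1, delta = 0.5).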

Examples

library(terra)
library(torch)
#Generate example data as SpatRasters (values are random, so the loss will vary between runs)
ref <- terra::rast(matrix(sample(c(1, 2, 3), 625, replace=TRUE), nrow=25, ncol=25))
pred1 <- terra::rast(matrix(sample(c(1:150), 625, replace=TRUE), nrow=25, ncol=25))
pred2 <- terra::rast(matrix(sample(c(1:150), 625, replace=TRUE), nrow=25, ncol=25))
pred3 <- terra::rast(matrix(sample(c(1:150), 625, replace=TRUE), nrow=25, ncol=25))
pred <- c(pred1, pred2, pred3)

#Convert SpatRaster to array
ref <- terra::as.array(ref)
pred <- terra::as.array(pred)

#Convert arrays to tensors and reshape
ref <- torch::torch_tensor(ref, dtype=torch::torch_long())
pred <- torch::torch_tensor(pred, dtype=torch::torch_float32())
ref <- ref$permute(c(3,1,2))
pred <- pred$permute(c(3,1,2))

#Add mini-batch dimension
ref <- ref$unsqueeze(1)
pred <- pred$unsqueeze(1)

#Duplicate tensors to have a batch of two
ref <- torch::torch_cat(list(ref, ref), dim=1)
pred <- torch::torch_cat(list(pred, pred), dim=1)

#Instantiate loss metric
myDiceLoss <- defineUnifiedFocalLoss(nCls=3,
                                    lambda=0, #Only use region-based loss
                                    gamma= 1,
                                    delta= 0.5, #Equal weights for FP and FN
                                    smooth = 1e-8,
                                    zeroStart=FALSE,
                                    clsWghtsDist=1,
                                    clsWghtsReg=1,
                                    useLogCosH =FALSE,
                                    device='cpu')
#Calculate loss
myDiceLoss(pred, ref)
#> torch_tensor
#> 0.667279
#> [ CPUFloatType{} ]