library(geodl)
library(torch)
Unified Focal Loss Framework
device <- torch_device("cuda")
Background
geodl provides a modified version of the unified focal loss proposed in the following study:
Yeung, M., Sala, E., Schönlieb, C.B. and Rundo, L., 2022. Unified focal loss: Generalising dice and cross entropy-based losses to handle class imbalanced medical image segmentation. Computerized Medical Imaging and Graphics, 95, p.102026.
The table below describes the loss parameterization. Our implementation differs from the original because we do not implement the symmetric and asymmetric forms. Instead, we allow the user to define class weights for both the distribution- and region-based loss components.
The lambda parameter controls the relative weighting between the distribution- and region-based losses. When lambda = 1, only the distribution-based loss is used. When lambda = 0, only the region-based loss is used. Values between 0 and 1 yield a loss metric that incorporates both the distribution- and region-based components. A lambda of 0.5 yields equal weighting. Values larger than 0.5 put more weight on the distribution-based loss, and values lower than 0.5 put more weight on the region-based loss.
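The behavior of lambda described above matches a convex combination of the two components. The sketch below assumes the form used by Yeung et al. (2022), loss = lambda × distribution loss + (1 − lambda) × region loss. `combineLosses()` is a hypothetical helper for illustration and is not part of geodl. The input loss values are made up.

```r
# Hypothetical helper (not part of geodl). It weights the distribution- and
# region-based losses with lambda, assuming a convex combination.
combineLosses <- function(distLoss, regLoss, lambda = 0.5) {
  stopifnot(lambda >= 0, lambda <= 1)
  lambda * distLoss + (1 - lambda) * regLoss
}

# Illustrative component values
combineLosses(distLoss = 1.2, regLoss = 0.3, lambda = 1)    # distribution-based only: 1.2
combineLosses(distLoss = 1.2, regLoss = 0.3, lambda = 0)    # region-based only: 0.3
combineLosses(distLoss = 1.2, regLoss = 0.3, lambda = 0.5)  # equal weighting: 0.75
```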
The gamma parameter controls the weight applied to difficult-to-predict samples. These are samples or classes with a low predicted rescaled logit for their correct class. For the distribution-based loss, focal corrections are applied sample-by-sample. For the region-based loss, corrections are applied class-by-class during macro-averaging. gamma must be greater than 0 and less than or equal to 1. When gamma = 1, no focal correction is applied, and lower values result in a stronger focal correction.
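The sketch below covers only the sample-by-sample (distribution-based) correction. It assumes a focal factor of (1 − pt)^(1 − gamma), modeled on the modified focal loss of Yeung et al. (2022). Here, pt is the predicted probability of the correct class. This form is an assumption and not necessarily geodl's exact formula. It does reproduce the behavior described above: gamma = 1 leaves the cross-entropy term unchanged, and smaller gamma values shrink easy samples more than hard ones. `focalCETerm()` is a hypothetical helper.

```r
# ASSUMED form (not necessarily geodl's formula): scale each sample's CE
# term by (1 - pt)^(1 - gamma), where pt is the probability of the correct
# class. gamma = 1 gives an exponent of 0, i.e., plain cross entropy.
focalCETerm <- function(pt, gamma = 1) {
  stopifnot(gamma > 0, gamma <= 1)
  (1 - pt)^(1 - gamma) * -log(pt)
}

pt <- c(easy = 0.9, hard = 0.2)   # illustrative predicted probabilities
focalCETerm(pt, gamma = 1)        # identical to -log(pt)
focalCETerm(pt, gamma = 0.5)      # the easy sample is down-weighted much more
```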
The delta parameter controls the relative weights of false negative (FN) and false positive (FP) errors and should be between 0 and 1. A delta of 0.5 places equal weight on FN and FP errors. Values greater than 0.5 place more weight on FN errors than on FP errors, while values less than 0.5 place more weight on FP errors than on FN errors.
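As a worked illustration, the Tversky index from Yeung et al. (2022) weights FN by delta and FP by 1 − delta. With delta = 0.5 it reduces to the Dice coefficient. `tverskyIndex()` is a hypothetical helper, and the counts are made up. The `smooth` term plays the role we assume geodl's `smooth` argument plays, which is guarding against division by zero.

```r
# Hypothetical helper: Tversky index from per-class (possibly soft) counts,
# assuming delta weights FN and 1 - delta weights FP.
tverskyIndex <- function(tp, fn, fp, delta = 0.5, smooth = 1e-8) {
  (tp + smooth) / (tp + delta * fn + (1 - delta) * fp + smooth)
}

# delta = 0.5: equals the Dice coefficient, 2 * TP / (2 * TP + FN + FP)
tverskyIndex(tp = 8, fn = 2, fp = 4, delta = 0.5)   # 8 / 11
2 * 8 / (2 * 8 + 2 + 4)                              # 16 / 22 = 8 / 11

# delta = 0.6: FN errors cost more than the same number of FP errors
tverskyIndex(tp = 8, fn = 4, fp = 2, delta = 0.6)   # FN-heavy: lower index
tverskyIndex(tp = 8, fn = 2, fp = 4, delta = 0.6)   # FP-heavy: higher index
```

The corresponding region-based loss for a class is 1 minus this index, so the FN-heavy case yields the larger loss.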
The clsWghtsDist parameter controls the relative weights of classes in the distribution-based loss and is applied sample-by-sample. The clsWghtsReg parameter controls the relative weights of classes in the region-based loss and is applied to each class when calculating the macro average. By default, all classes are weighted equally. To use class weights, provide a vector of weights whose length equals the number of classes being differentiated.
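A class-weighted macro average of per-class region-based losses can be sketched as a weighted mean. Normalizing by the sum of the weights is an assumption, and geodl's internal normalization may differ. `weightedMacroLoss()` is a hypothetical helper. A scalar weight of 1, as passed in the examples below, expands to equal weights.

```r
# Hypothetical helper: weighted macro average of per-class losses.
# ASSUMPTION: weights are normalized by their sum.
weightedMacroLoss <- function(clsLoss, clsWghts = 1) {
  if (length(clsWghts) == 1) clsWghts <- rep(clsWghts, length(clsLoss))
  stopifnot(length(clsWghts) == length(clsLoss))   # one weight per class
  sum(clsWghts * clsLoss) / sum(clsWghts)
}

clsLoss <- c(0.10, 0.20, 0.60, 0.15, 0.25)     # illustrative losses, 5 classes
weightedMacroLoss(clsLoss)                     # equal weights: plain mean, 0.26
weightedMacroLoss(clsLoss, c(1, 1, 3, 1, 1))   # up-weight class 3: 2.5 / 7
```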
Lastly, the useLogCosH parameter determines whether a log-cosh transformation is applied to the region-based loss. TRUE applies the transformation, and FALSE does not.
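The log-cosh transformation is log(cosh(x)). For a region-based loss value x between 0 and 1, it is monotonically increasing and behaves like x²/2 near zero, which keeps it smooth there. The sketch applies it to a few illustrative loss values. `logCosh()` is a hypothetical helper, and exactly where geodl applies the transformation within the region-based loss is not shown here.

```r
# log-cosh transformation of a region-based loss value x
logCosh <- function(x) log(cosh(x))

x <- c(0, 0.1, 0.5, 1)   # illustrative region-based loss values
logCosh(x)               # increasing in x; close to x^2 / 2 for small x
```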
By varying these parameters, users can define a variety of loss metrics, including cross entropy (CE) loss, weighted CE loss, focal CE loss, focal weighted CE loss, Dice loss, focal Dice loss, Tversky loss, and focal Tversky loss.
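The loss metrics listed above can be summarized as a small lookup of parameter settings. The settings follow the descriptions and examples in this document. "< 1" stands for any gamma in (0, 1), and "!= 0.5" stands for any delta other than 0.5. `lossSettings` is an illustrative data frame, not a geodl object.

```r
# Illustrative summary of settings that yield each named loss
lossSettings <- data.frame(
  loss         = c("CE", "Weighted CE", "Focal CE", "Focal weighted CE",
                   "Dice", "Focal Dice", "Tversky", "Focal Tversky"),
  lambda       = c(1, 1, 1, 1, 0, 0, 0, 0),                       # 1 = distribution, 0 = region
  gamma        = c("1", "1", "< 1", "< 1", "1", "< 1", "1", "< 1"),   # < 1 adds focal correction
  delta        = c("0.5", "0.5", "0.5", "0.5", "0.5", "0.5", "!= 0.5", "!= 0.5"),
  clsWghtsDist = c("equal", "unequal", "equal", "unequal", "n/a", "n/a", "n/a", "n/a")
)
lossSettings
```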
Examples
We will now demonstrate how different parameterizations yield different loss metrics. We first use the rast() function from terra to load example data representing reference class codes (refC) and predicted class logits (predL).
refC <- terra::rast("data/geodl/metricCheck/multiclass_reference.tif")
predL <- terra::rast("data/geodl/metricCheck/multiclass_logits.tif")
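The SpatRaster objects are then converted to arrays and then to torch tensors. The reference codes become a long (integer) tensor and the logits a 32-bit float tensor, each ordered as [batch, channel, row, column]. We simulate a mini-batch of two samples by concatenating two copies of each tensor along the batch dimension.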
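# Convert the SpatRasters to arrays ordered as [rows, columns, layers]
predL <- terra::as.array(predL)
refC <- terra::as.array(refC)
# Integer class codes for the target, 32-bit floats for the logits
target <- torch::torch_tensor(refC, dtype=torch::torch_long(), device=device)
pred <- torch::torch_tensor(predL, dtype=torch::torch_float32(), device=device)
# Move the layer (channel) dimension first: [channels, rows, columns]
target <- target$permute(c(3,1,2))
pred <- pred$permute(c(3,1,2))
# Add a leading batch dimension: [1, channels, rows, columns]
target <- target$unsqueeze(1)
pred <- pred$unsqueeze(1)
# Simulate a mini-batch of two samples: [2, channels, rows, columns]
target <- torch::torch_cat(list(target, target), dim=1)
pred <- torch::torch_cat(list(pred, pred), dim=1)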
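The dimension reordering in the preceding code can be checked without torch. In the sketch below, base R's aperm() stands in for torch's $permute(), since both use the same convention: new dimension i is old dimension perm[i]. The array is a toy stand-in for the terra output, which is ordered as [rows, columns, layers].

```r
# Toy array: 2 rows, 3 columns, 4 layers (as returned by terra::as.array())
a <- array(1:24, dim = c(2, 3, 4))
# Reorder to [layers, rows, columns], like $permute(c(3,1,2))
b <- aperm(a, c(3, 1, 2))
dim(b)                     # 4 2 3
b[4, 2, 3] == a[2, 3, 4]   # TRUE: same cell, new index order
```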
Example 1: Dice Loss
The Dice loss is obtained by setting the lambda parameter to 0, the gamma parameter to 1, and the delta parameter to 0.5. This results in only the region-based loss being considered, no focal correction being applied, and equal weighting between FN and FP errors.
myDiceLoss <- defineUnifiedFocalLoss(nCls=5,
                                     lambda=0, #Only use region-based loss
                                     gamma=1,
                                     delta=0.5, #Equal weights for FP and FN
                                     smooth=1e-8,
                                     zeroStart=TRUE,
                                     clsWghtsDist=1,
                                     clsWghtsReg=1,
                                     useLogCosH=FALSE,
                                     device=device)
myDiceLoss(pred=pred,
target=target)
torch_tensor
0.17861
[ CUDAFloatType{} ]
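With lambda = 0, only the region-based loss is computed. To illustrate that idea, the base-R sketch below derives per-class soft TP, FP, and FN counts from class probabilities, computes a Dice coefficient per class, and macro-averages 1 − Dice. It is a simplified, single-image illustration rather than geodl's implementation. `softmaxCols()` and `softDiceLoss()` are hypothetical helpers. The sketch also assumes two things: that the logits are rescaled with a softmax, and that zeroStart=TRUE means class codes start at 0.

```r
# Hypothetical helper: column-wise softmax (classes in rows, pixels in columns)
softmaxCols <- function(z) {
  e <- exp(sweep(z, 2, apply(z, 2, max)))   # subtract column max for stability
  sweep(e, 2, colSums(e), "/")
}

# Hypothetical, simplified soft Dice loss for one image (not geodl's code).
# logits: [nCls, nPixels]; labels: class codes starting at 0 (assumed zeroStart).
softDiceLoss <- function(logits, labels, nCls, smooth = 1e-8) {
  p <- softmaxCols(logits)                                             # probabilities
  y <- t(sapply(0:(nCls - 1), function(k) as.numeric(labels == k)))  # one-hot
  tp <- rowSums(p * y)                                                 # soft counts per class
  fp <- rowSums(p * (1 - y))
  fn <- rowSums((1 - p) * y)
  dice <- (2 * tp + smooth) / (2 * tp + fp + fn + smooth)
  mean(1 - dice)                                                       # macro average
}

labels <- c(0, 1, 2, 1)                          # illustrative 4-pixel reference
softDiceLoss(matrix(0, 3, 4), labels, nCls = 3)  # uninformative logits: 71 / 105
conf <- matrix(-10, 3, 4)
conf[cbind(labels + 1, 1:4)] <- 10               # large logit at the correct class
softDiceLoss(conf, labels, nCls = 3)             # close to 0
```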
Example 2: Tversky Loss
The Tversky loss can be obtained using the same settings as the Dice loss, except that delta must be set to a value other than 0.5 so that FN and FP errors receive different weights. Here, we use a delta of 0.6, which places more weight on FN errors than on FP errors. Setting gamma to a value lower than 1 results in a focal Tversky loss.
Note that we regenerate the tensors so that the computational graphs are re-initialized.
target <- torch::torch_tensor(refC, dtype=torch::torch_long(), device=device)
pred <- torch::torch_tensor(predL, dtype=torch::torch_float32(), device=device)
target <- target$permute(c(3,1,2))
pred <- pred$permute(c(3,1,2))
target <- target$unsqueeze(1)
pred <- pred$unsqueeze(1)
target <- torch::torch_cat(list(target, target), dim=1)
pred <- torch::torch_cat(list(pred, pred), dim=1)
myTverskyLoss <- defineUnifiedFocalLoss(nCls=5,
                                        lambda=0, #Only use region-based loss
                                        gamma=1,
                                        delta=0.6, #FN weighted higher than FP
                                        smooth=1e-8,
                                        zeroStart=TRUE,
                                        clsWghtsDist=1,
                                        clsWghtsReg=1,
                                        useLogCosH=FALSE,
                                        device=device)
myTverskyLoss(pred=pred,
target=target)
torch_tensor
0.177986
[ CUDAFloatType{} ]
Example 3: Cross Entropy (CE) Loss
The cross entropy (CE) loss is obtained by setting lambda to 1, so that only the distribution-based loss is considered, and setting gamma to 1, so that no focal correction is applied. Setting gamma to a value lower than 1 results in a focal CE loss.
target <- torch::torch_tensor(refC, dtype=torch::torch_long(), device=device)
pred <- torch::torch_tensor(predL, dtype=torch::torch_float32(), device=device)
target <- target$permute(c(3,1,2))
pred <- pred$permute(c(3,1,2))
target <- target$unsqueeze(1)
pred <- pred$unsqueeze(1)
target <- torch::torch_cat(list(target, target), dim=1)
pred <- torch::torch_cat(list(pred, pred), dim=1)
myCELoss <- defineUnifiedFocalLoss(nCls=5,
                                   lambda=1, #Only use distribution-based loss
                                   gamma=1,
                                   delta=0.5,
                                   smooth=1e-8,
                                   zeroStart=TRUE,
                                   clsWghtsDist=1,
                                   clsWghtsReg=1,
                                   useLogCosH=FALSE,
                                   device=device)
myCELoss(pred=pred,
target=target)
torch_tensor
1.58689
[ CUDAFloatType{} ]
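For intuition about what the distribution-based component measures, the sketch below computes a plain multiclass cross entropy in base R. For each pixel it takes −log of the softmax probability of the correct class, then averages over pixels. It is a simplified single-image illustration, not geodl's implementation. `crossEntropy()` is a hypothetical helper, and class codes are assumed to start at 0. A focal version would scale each pixel's term before averaging.

```r
# Hypothetical, simplified multiclass cross entropy (not geodl's code).
# logits: [nCls, nPixels]; labels: class codes starting at 0.
crossEntropy <- function(logits, labels) {
  m <- apply(logits, 2, max)
  logSumExp <- m + log(colSums(exp(sweep(logits, 2, m))))   # stable log-sum-exp
  correct <- logits[cbind(labels + 1, seq_along(labels))]   # logit of true class
  mean(logSumExp - correct)                                  # mean of -log(p_true)
}

labels <- c(0, 1, 2, 1)                 # illustrative 4-pixel reference
crossEntropy(matrix(0, 3, 4), labels)   # uniform predictions over 3 classes: log(3)
```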
Example 4: Combo Loss
A combo loss can be obtained by setting lambda to a value between 0 and 1. Here, we use 0.5, which weights the distribution- and region-based losses equally. We also apply a focal correction using a gamma of 0.8 and weight FN errors more heavily than FP errors using a delta of 0.6. The result is a combination of the focal CE and focal Tversky losses.
target <- torch::torch_tensor(refC, dtype=torch::torch_long(), device=device)
pred <- torch::torch_tensor(predL, dtype=torch::torch_float32(), device=device)
target <- target$permute(c(3,1,2))
pred <- pred$permute(c(3,1,2))
target <- target$unsqueeze(1)
pred <- pred$unsqueeze(1)
target <- torch::torch_cat(list(target, target), dim=1)
pred <- torch::torch_cat(list(pred, pred), dim=1)
myComboLoss <- defineUnifiedFocalLoss(nCls=5,
                                      lambda=0.5, #Use both distribution- and region-based losses
                                      gamma=0.8, #Apply a focal adjustment
                                      delta=0.6, #Weight FN higher than FP
                                      smooth=1e-8,
                                      zeroStart=TRUE,
                                      clsWghtsDist=1,
                                      clsWghtsReg=1,
                                      useLogCosH=FALSE,
                                      device=device)
myComboLoss(pred=pred,
target=target)
torch_tensor
0.9143
[ CUDAFloatType{1} ]