`device <- torch_device("cuda")`

## Background

**geodl** provides a modified version of the
**unified focal loss** proposed in the following study:

Yeung, M., Sala, E., Schönlieb, C.B. and Rundo, L., 2022. Unified
focal loss: Generalising dice and cross entropy-based losses to handle
class imbalanced medical image segmentation. *Computerized Medical
Imaging and Graphics*, 95, p.102026.

The table below describes the loss parameterization. Our implementation differs from the one originally proposed in that it does not implement the symmetric and asymmetric forms. Instead, the user can define class weights for both the distribution- and region-based loss components.

The *lambda* parameter controls the relative weighting between
the distribution- and region-based losses. When *lambda* = 1,
only the distribution-based loss is used. When *lambda* = 0, only
the region-based loss is used. Values between 0 and 1 yield a loss
metric that incorporates both the distribution- and region-based loss
components. A *lambda* of 0.5 yields equal weighting, values
larger than 0.5 put more weight on the distribution-based loss, and
values lower than 0.5 put more weight on the region-based loss.
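The weighting described above can be illustrated with a minimal base R sketch. The helper `combineLosses()` and the two loss values are hypothetical, and the linear blend is an assumption based on the description here rather than a copy of the **geodl** source.

```r
# Hypothetical helper: blend a distribution-based and a region-based loss.
# Assumes a simple convex combination controlled by lambda.
combineLosses <- function(distLoss, regLoss, lambda) {
  stopifnot(lambda >= 0, lambda <= 1)
  lambda * distLoss + (1 - lambda) * regLoss
}

distLoss <- 1.6  # made-up distribution-based (CE-like) loss
regLoss  <- 0.2  # made-up region-based (Dice-like) loss

onlyDist <- combineLosses(distLoss, regLoss, lambda = 1)    # distribution-based only
onlyReg  <- combineLosses(distLoss, regLoss, lambda = 0)    # region-based only
equalMix <- combineLosses(distLoss, regLoss, lambda = 0.5)  # equal weighting
print(c(onlyDist, onlyReg, equalMix))
```

Under this assumption, a *lambda* of 0.5 simply averages the two components.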

The *gamma* parameter controls the weight applied to
difficult-to-predict samples, defined as samples or classes with low
predicted rescaled logits for their correct class. For the
distribution-based loss, focal corrections are implemented
sample-by-sample. For the region-based loss, corrections are implemented
class-by-class during macro-averaging. *gamma* must be larger
than 0 and less than or equal to 1. When *gamma* = 1, no focal
correction is applied. Lower values result in a larger focal
correction.
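As a rough illustration of how *gamma* behaves, the base R sketch below follows the focal forms given by Yeung et al. (2022): a factor of `(1 - p)^(1 - gamma)` on the per-sample CE and an exponent of `gamma` on one minus the per-class Tversky index. Both helper functions are hypothetical, and the exact expressions used inside **geodl** may differ.

```r
# Hypothetical helpers following the Yeung et al. (2022) formulation;
# this is an assumed form, not the geodl source code.
focalCE <- function(pTrue, gamma) {
  ce <- -log(pTrue)              # plain cross entropy per sample
  (1 - pTrue)^(1 - gamma) * ce   # exponent is 0 when gamma = 1: no correction
}

focalRegion <- function(tverskyIndex, gamma) {
  (1 - tverskyIndex)^gamma       # gamma = 1 leaves 1 - index unchanged
}

pTrue <- c(easy = 0.9, hard = 0.3)  # predicted probability of the correct class

plainCE  <- focalCE(pTrue, gamma = 1)
focalCE5 <- focalCE(pTrue, gamma = 0.5)

# Ratio of hard-sample to easy-sample loss, without and with the correction
ratioPlain <- unname(plainCE["hard"] / plainCE["easy"])
ratioFocal <- unname(focalCE5["hard"] / focalCE5["easy"])
print(c(ratioPlain, ratioFocal))
```

In this sketch, lowering *gamma* increases the share of the loss contributed by the hard sample, which is consistent with the statement that lower values give a larger focal correction.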

The *delta* parameter controls the relative weights of false negative (FN) and
false positive (FP) errors and should be between 0 and 1. A *delta* of 0.5 places
equal weight on FN and FP errors. Values larger than 0.5 place more
weight on FN errors than on FP errors, while values smaller than 0.5
place more weight on FP errors than on FN errors.
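The effect of *delta* can be sketched with the standard Tversky index, TP / (TP + delta * FN + (1 - delta) * FP). The helper `tverskyIndex()` and the counts below are hypothetical; the real loss operates on predicted probabilities and also uses a smoothing term, which is omitted here for clarity.

```r
# Hypothetical helper: Tversky index from hard counts for a single class.
# Assumes delta weights FN and (1 - delta) weights FP.
tverskyIndex <- function(tp, fn, fp, delta) {
  tp / (tp + delta * fn + (1 - delta) * fp)
}

tp <- 80; fn <- 15; fp <- 5  # made-up counts with more FN than FP

diceFromCounts <- 2 * tp / (2 * tp + fn + fp)           # classic Dice coefficient
tvEqual <- tverskyIndex(tp, fn, fp, delta = 0.5)       # equals Dice
tvFN    <- tverskyIndex(tp, fn, fp, delta = 0.6)       # FN penalized more

print(c(dice = diceFromCounts, delta0.5 = tvEqual, delta0.6 = tvFN))
```

Because this class has more FN than FP errors, raising *delta* above 0.5 lowers the index and therefore raises the corresponding loss (one minus the index).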

The *clsWghtsDist* parameter controls the relative weights of
classes in the distribution-based loss and is applied sample-by-sample.
The *clsWghtsReg* parameter controls the relative weights of
classes in the region-based loss and is applied to each class when
calculating a macro average. By default, all classes are weighted
equally. If you want to implement class weights, you must provide a
vector of class weights equal in length to the number of classes being
differentiated.
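A weighted macro average of the kind described above can be sketched in base R with `weighted.mean()`. The per-class loss values and weights are made up, and normalizing by the sum of the weights is an assumption; **geodl** may scale the weighted terms differently.

```r
# Made-up per-class region-based losses for a five-class problem
perClassLoss <- c(0.10, 0.25, 0.40, 0.15, 0.60)

equalWghts  <- rep(1, 5)          # default: all classes weighted equally
customWghts <- c(1, 1, 2, 1, 3)   # hypothetical weights, one per class

# The weight vector must match the number of classes
stopifnot(length(customWghts) == length(perClassLoss))

macroEqual    <- weighted.mean(perClassLoss, equalWghts)   # plain macro average
macroWeighted <- weighted.mean(perClassLoss, customWghts)  # emphasizes classes 3 and 5
print(c(macroEqual, macroWeighted))
```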

Lastly, the *useLogCosH* parameter determines whether or not
to apply a log cosh transformation to the region-based loss. If it is
set to TRUE, this transformation is applied.
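For reference, the log cosh transformation itself is simply `log(cosh(x))`. The short base R sketch below applies it to a few made-up loss values; how **geodl** applies it internally (for example, per class or to the aggregated loss) is not shown here.

```r
# Log cosh transformation of a loss value
logCosh <- function(x) log(cosh(x))

regLoss <- c(0, 0.2, 0.5, 1)  # made-up region-based loss values
smoothed <- logCosh(regLoss)
print(smoothed)
```

The transformation maps 0 to 0, behaves roughly like x^2 / 2 for small values, and is always smaller than the untransformed positive loss.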

Using different parameterizations, users can define a variety of loss metrics, including cross entropy (CE) loss, weighted CE loss, focal CE loss, focal weighted CE loss, Dice loss, focal Dice loss, Tversky loss, and focal Tversky loss.

## Examples

We will now demonstrate how different loss metrics can be obtained
using different parameterizations. We first load example data with
the *rast()* function from **terra**: a raster of numeric reference
class codes (refC) and a raster of predicted class logits
(predL).

```
refC <- terra::rast("C:/myFiles/data/metricCheck/multiclass_reference.tif")
predL <- terra::rast("C:/myFiles/data/metricCheck/multiclass_logits.tif")
```

The SpatRaster objects are then converted to **torch**
tensors with the correct shape and data type. We simulate a mini-batch
of two samples by concatenating two copies of the tensors.

```
predL <- terra::as.array(predL)
refC <- terra::as.array(refC)
target <- torch::torch_tensor(refC, dtype=torch::torch_long(), device=device)
pred <- torch::torch_tensor(predL, dtype=torch::torch_float32(), device=device)
target <- target$permute(c(3,1,2))
pred <- pred$permute(c(3,1,2))
target <- target$unsqueeze(1)
pred <- pred$unsqueeze(1)
target <- torch::torch_cat(list(target, target), dim=1)
pred <- torch::torch_cat(list(pred, pred), dim=1)
```
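The reshaping steps above can be mimicked on the CPU with base R arrays, which may help when checking shapes without a GPU. This is only an analogue: `aperm()` plays the role of `$permute()`, and filling a pre-allocated array stands in for `$unsqueeze()` followed by `torch_cat()`. The array sizes are made up, and the sketch assumes the logits raster has one band per class.

```r
set.seed(42)
h <- 4; w <- 3; nCls <- 5  # made-up height, width, and class count

# terra::as.array() returns (rows, cols, bands); simulate that layout
logits <- array(rnorm(h * w * nCls), dim = c(h, w, nCls))

# Analogue of $permute(c(3,1,2)): move bands first -> (bands, rows, cols)
chw <- aperm(logits, c(3, 1, 2))

# Analogue of $unsqueeze(1) + torch_cat(..., dim=1): stack two copies
# along a new leading batch dimension -> (batch, bands, rows, cols)
batch <- array(NA_real_, dim = c(2, nCls, h, w))
batch[1, , , ] <- chw
batch[2, , , ] <- chw

print(dim(batch))
```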

### Example 1: Dice Loss

The **Dice loss** is obtained by setting the
*lambda* parameter to 0, the *gamma* parameter to 1, and
the *delta* parameter to 0.5. This results in only the
region-based loss being considered, no focal correction being applied,
and equal weighting between FN and FP errors.

```
myDiceLoss <- defineUnifiedFocalLoss(nCls=5,
                                     lambda=0, #Only use region-based loss
                                     gamma=1,
                                     delta=0.5, #Equal weights for FP and FN
                                     smooth=1e-8,
                                     zeroStart=TRUE,
                                     clsWghtsDist=1,
                                     clsWghtsReg=1,
                                     useLogCosH=FALSE,
                                     device=device)
myDiceLoss(pred=pred,
           target=target)
torch_tensor
0.17861
[ CUDAFloatType{} ]
```

### Example 2: Tversky Loss

The **Tversky Loss** can be obtained using the same
settings as those used for the **Dice loss** except that
the *delta* parameter must be set to a value other than 0.5 so
that different weights are applied to FN and FP errors. In the example,
we use a weighting of 0.6, which places more weight on FN errors
relative to FP errors. Setting *gamma* to a value lower than 1
results in a **focal Tversky loss**.

Note that we regenerate the tensors so that the computational graphs are re-initialized.

```
target <- torch::torch_tensor(refC, dtype=torch::torch_long(), device=device)
pred <- torch::torch_tensor(predL, dtype=torch::torch_float32(), device=device)
target <- target$permute(c(3,1,2))
pred <- pred$permute(c(3,1,2))
target <- target$unsqueeze(1)
pred <- pred$unsqueeze(1)
target <- torch::torch_cat(list(target, target), dim=1)
pred <- torch::torch_cat(list(pred, pred), dim=1)

myTverskyLoss <- defineUnifiedFocalLoss(nCls=5,
                                        lambda=0, #Only use region-based loss
                                        gamma=1,
                                        delta=0.6, #FN weighted higher than FP
                                        smooth=1e-8,
                                        zeroStart=TRUE,
                                        clsWghtsDist=1,
                                        clsWghtsReg=1,
                                        useLogCosH=FALSE,
                                        device=device)
myTverskyLoss(pred=pred,
              target=target)
torch_tensor
0.177986
[ CUDAFloatType{} ]
```

### Example 3: Cross Entropy (CE) Loss

The **cross entropy** (**CE**) loss is
obtained by setting *lambda* to 1, so that only the
distribution-based loss is considered, and setting *gamma* to 1,
so that no focal correction is applied. Setting *gamma* to a
value lower than 1 results in a **focal CE loss**.

```
target <- torch::torch_tensor(refC, dtype=torch::torch_long(), device=device)
pred <- torch::torch_tensor(predL, dtype=torch::torch_float32(), device=device)
target <- target$permute(c(3,1,2))
pred <- pred$permute(c(3,1,2))
target <- target$unsqueeze(1)
pred <- pred$unsqueeze(1)
target <- torch::torch_cat(list(target, target), dim=1)
pred <- torch::torch_cat(list(pred, pred), dim=1)

myCELoss <- defineUnifiedFocalLoss(nCls=5,
                                   lambda=1, #Only use distribution-based loss
                                   gamma=1,
                                   delta=0.5,
                                   smooth=1e-8,
                                   zeroStart=TRUE,
                                   clsWghtsDist=1,
                                   clsWghtsReg=1,
                                   useLogCosH=FALSE,
                                   device=device)
myCELoss(pred=pred,
         target=target)
torch_tensor
1.58689
[ CUDAFloatType{} ]
```
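To make the distribution-based term more concrete, the base R sketch below computes a plain CE value for a single pixel from its class logits: the logits are rescaled with a softmax, and the loss is the negative log of the probability assigned to the correct class. The logits and the reference code are made up, and the zero-based class code is an assumption that mirrors `zeroStart=TRUE` in the example above.

```r
# Numerically stable softmax over a vector of class logits
softmax <- function(z) {
  e <- exp(z - max(z))
  e / sum(e)
}

logits  <- c(2.0, 0.5, -1.0, 0.0, 1.0)  # made-up logits for five classes
refCode <- 0                             # assumed zero-based reference class code

probs <- softmax(logits)
ce <- -log(probs[refCode + 1])           # R indexing starts at 1
print(ce)
```

Averaging this quantity over all pixels in a mini-batch gives the usual unweighted CE loss.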

### Example 4: Combo-Loss

A **combo-loss** can be obtained by setting
*lambda* to a value between 0 and 1. In the example, we have used
0.5, which results in equal weights being applied to the distribution-
and region-based losses. We also apply a focal correction using a
*gamma* of 0.8 and weight FN errors higher than FP errors by
using a *delta* of 0.6. The result is a combination of the
**focal CE** and **focal Tversky** losses.

```
target <- torch::torch_tensor(refC, dtype=torch::torch_long(), device=device)
pred <- torch::torch_tensor(predL, dtype=torch::torch_float32(), device=device)
target <- target$permute(c(3,1,2))
pred <- pred$permute(c(3,1,2))
target <- target$unsqueeze(1)
pred <- pred$unsqueeze(1)
target <- torch::torch_cat(list(target, target), dim=1)
pred <- torch::torch_cat(list(pred, pred), dim=1)

myComboLoss <- defineUnifiedFocalLoss(nCls=5,
                                      lambda=.5, #Use both distribution and region-based losses
                                      gamma=0.8, #Apply a focal adjustment
                                      delta=0.6, #Weight FN higher than FP
                                      smooth=1e-8,
                                      zeroStart=TRUE,
                                      clsWghtsDist=1,
                                      clsWghtsReg=1,
                                      useLogCosH=FALSE,
                                      device=device)
myComboLoss(pred=pred,
            target=target)
torch_tensor
0.9143
[ CUDAFloatType{1} ]
```