Define a UNet architecture for geospatial semantic segmentation with a MobileNet-v2 backbone.

Usage

defineMobileUNet(
  nCls = 3,
  pretrainedEncoder = TRUE,
  freezeEncoder = TRUE,
  actFunc = "relu",
  useAttn = FALSE,
  useDS = FALSE,
  dcChn = c(256, 128, 64, 32, 16),
  negative_slope = 0.01
)

Arguments

nCls

Number of classes being differentiated. For binary classification, this can be either 1 or 2. If 2, the problem is treated as a multiclass problem, and a multiclass loss metric should be used. Default is 3.

pretrainedEncoder

TRUE or FALSE. Whether or not to initialize the MobileNet-v2 encoder using pre-trained ImageNet weights. Default is TRUE.

freezeEncoder

TRUE or FALSE. Whether or not to freeze the encoder during training. The default is TRUE. If TRUE, only the decoder component is trained.

actFunc

Defines the activation function to use throughout the network (the MobileNet-v2 encoder layers are not affected). "relu" = rectified linear unit (ReLU); "lrelu" = leaky ReLU; "swish" = swish. Default is "relu". See the sketch following this argument list for an example that uses "lrelu" together with negative_slope.

useAttn

TRUE or FALSE. Whether to add attention gates along the skip connections. Default is FALSE (no attention gates are added).

useDS

TRUE or FALSE. Whether or not to use deep supervision. If TRUE, four predictions are made, one at each of the four largest decoder block resolutions, and they are returned as a list containing the four predictions. If FALSE, only the final prediction at the original resolution is returned. Default is FALSE (deep supervision is not used).

dcChn

Vector of 5 integers defining the number of output feature maps for each of the 5 decoder blocks. Default is 256, 128, 64, 32, and 16.

negative_slope

If actFunc = "lrelu", specifies the negative slope term to use. Default is 0.01.
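
As a rough sketch of how these arguments combine (the values below are illustrative, not recommendations), a single-channel binary setup with a leaky ReLU activation could be defined as follows:

require(torch)

# Illustrative binary setup: one output channel and a leaky ReLU activation.
# negative_slope only has an effect because actFunc = "lrelu".
binModel <- defineMobileUNet(nCls = 1,
                             pretrainedEncoder = FALSE,
                             freezeEncoder = FALSE,
                             actFunc = "lrelu",
                             negative_slope = 0.01)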

Value

A MobileUNet model instance as a torch nn_module.

Details

Define a UNet architecture with a MobileNet-v2 backbone or encoder. This UNet implementation was inspired by a blog post by Sigrid Keydana available here. This architecture has 6 blocks in the encoder (including the bottleneck) and 5 blocks in the decoder. The user is able to implement deep supervision (useDS = TRUE) and attention gates along the skip connections (useAttn = TRUE). This model requires three input bands or channels.

Examples

require(torch)
# Generate example data as a torch tensor
tensorIn <- torch::torch_rand(c(12, 3, 128, 128))

# Instantiate the model
model <- defineMobileUNet(nCls = 3,
                          pretrainedEncoder = FALSE,
                          freezeEncoder = FALSE,
                          actFunc = "relu",
                          useAttn = TRUE,
                          useDS = TRUE,
                          dcChn = c(256, 128, 64, 32, 16),
                          negative_slope = 0.01)

pred <- model(tensorIn)
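
Because useDS = TRUE above, pred is a list of four predictions rather than a single tensor, as described under useDS. One way to confirm this is to inspect each element's shape (the exact spatial sizes depend on the decoder configuration, so none are assumed here):

# pred is a list of four tensors when useDS = TRUE; print each tensor's shape
lapply(pred, function(p) p$shape)

# With useDS = FALSE, a single prediction at the input resolution is returned
modelNoDS <- defineMobileUNet(nCls = 3,
                              pretrainedEncoder = FALSE,
                              freezeEncoder = FALSE,
                              useDS = FALSE)
predNoDS <- modelNoDS(tensorIn)
predNoDS$shape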