
Define a UNet architecture for geospatial semantic segmentation.

Usage

defineUNet(
  inChn = 3,
  nCls = 3,
  actFunc = "relu",
  useAttn = FALSE,
  useSE = FALSE,
  useRes = FALSE,
  useASPP = FALSE,
  useDS = FALSE,
  enChn = c(16, 32, 64, 128),
  dcChn = c(128, 64, 32, 16),
  btnChn = 256,
  dilRates = c(1, 2, 4, 8, 16),
  dilChn = c(16, 16, 16, 16, 16),
  negative_slope = 0.01,
  seRatio = 8
)

Arguments

inChn

Number of channels, bands, or predictor variables in the input image or raster data. Default is 3.

nCls

Number of classes being differentiated. For a binary classification, this can be either 1 or 2. If 2, the problem is treated as a multiclass problem, and a multiclass loss metric should be used. Default is 3.

actFunc

Activation function used throughout the network: "relu" = rectified linear unit (ReLU); "lrelu" = leaky ReLU; "swish" = swish. Default is "relu".

useAttn

TRUE or FALSE. Whether to add attention gates along the skip connections. Default is FALSE (no attention gates are added).

useSE

TRUE or FALSE. Whether to include squeeze and excitation modules in the encoder. Default is FALSE (no squeeze and excitation modules are used).

useRes

TRUE or FALSE. Whether to include residual connections in the encoder, decoder, and bottleneck/ASPP module blocks. Default is FALSE (no residual connections are included).

useASPP

TRUE or FALSE. Whether to use an ASPP module as the bottleneck instead of a double convolution operation. Default is FALSE (a double convolution is used as the bottleneck).

useDS

TRUE or FALSE. Whether to use deep supervision. If TRUE, four predictions are made, one at the resolution of each decoder block, and they are returned together as a list of four predictions. If FALSE, only the final prediction at the original resolution is returned. Default is FALSE (deep supervision is not used).

enChn

Vector of four integers defining the number of output feature maps for each of the four encoder blocks. Default is 16, 32, 64, and 128.

dcChn

Vector of four integers defining the number of output feature maps for each of the four decoder blocks. Default is 128, 64, 32, and 16.

btnChn

Number of output feature maps from the bottleneck block. Default is 256.

dilRates

Vector of five values specifying the dilation rates used in the ASPP module. Default is 1, 2, 4, 8, and 16.
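A dilation rate d spaces the kernel taps d pixels apart, so larger rates let a branch see a wider neighborhood without adding weights. The short base R sketch below computes the effective extent k + (k - 1)(d - 1) of each branch. It assumes each ASPP branch uses a 3 × 3 kernel; the kernel size is not documented on this page, and `effectiveKernel()` is a hypothetical helper, not part of the package.

```r
# Hypothetical helper: effective spatial extent of a dilated (atrous) kernel.
# ASSUMPTION: each ASPP branch uses a square kernel of size k = 3.
effectiveKernel <- function(dilRates, k = 3) {
  k + (k - 1) * (dilRates - 1)
}

dilRates <- c(1, 2, 4, 8, 16)   # default value of the dilRates argument
effectiveKernel(dilRates)       # returns c(3, 5, 9, 17, 33)
```

Under this assumption, the rate-1 branch behaves like an ordinary 3 × 3 convolution, while the rate-16 branch spans 33 pixels.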

dilChn

Vector of five values specifying the number of channels produced at each dilation rate within the ASPP module. Default is 16 for each dilation rate, or 80 channels in total.
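Because `dilRates` and `dilChn` describe the same five ASPP branches, their lengths must match, and the total channel count is the sum of `dilChn`. The sketch below is a hypothetical pre-check (`checkASPPArgs()` is not part of the package) and assumes the branch outputs are concatenated, which is consistent with the "80 channels in total" figure above.

```r
# Hypothetical helper (not part of the package): validates the two ASPP
# arguments and returns the total number of channels, ASSUMING the five
# branch outputs are concatenated along the channel dimension.
checkASPPArgs <- function(dilRates = c(1, 2, 4, 8, 16),
                          dilChn = c(16, 16, 16, 16, 16)) {
  stopifnot(length(dilRates) == 5, length(dilChn) == 5)
  stopifnot(all(dilRates >= 1), all(dilChn >= 1))
  sum(dilChn)
}

checkASPPArgs()                               # returns 80 (the defaults)
checkASPPArgs(dilChn = c(32, 32, 16, 16, 8))  # returns 104
```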

negative_slope

If actFunc = "lrelu", specifies the negative slope term to use. Default is 0.01.
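For reference, the three `actFunc` options and the role of `negative_slope` can be written as plain base R functions. These are illustrative textbook definitions, not the package's implementation (the network itself uses torch modules). The swish shown is the β = 1 form, x · sigmoid(x), also called SiLU. That the package uses this form is an assumption.

```r
# Illustrative base R definitions (hypothetical helpers, not package code).
relu  <- function(x) pmax(x, 0)
lrelu <- function(x, negative_slope = 0.01) {
  ifelse(x > 0, x, negative_slope * x)   # negative inputs are scaled, not zeroed
}
swish <- function(x) x / (1 + exp(-x))   # x * sigmoid(x); ASSUMES beta = 1

x <- c(-2, 0, 3)
relu(x)    # returns c(0, 0, 3)
lrelu(x)   # returns c(-0.02, 0, 3)
```

Unlike ReLU, leaky ReLU keeps a small gradient for negative inputs. `negative_slope` sets that gradient.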

seRatio

Reduction ratio used in the squeeze and excitation modules. Default is 8.
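In the original squeeze and excitation design, a module squeezes C channels down to C / ratio before expanding back to C. The sketch below assumes the package follows that convention and that the modules see the encoder channel counts in `enChn`. Integer division is used for illustration because the package's rounding behavior is not documented here. `seHidden()` is a hypothetical helper.

```r
# Hypothetical helper: hidden width of each encoder SE module.
# ASSUMPTION: hidden width = channels / seRatio, as in the original SE design.
seHidden <- function(enChn = c(16, 32, 64, 128), seRatio = 8) {
  enChn %/% seRatio
}

seHidden()              # returns c(2, 4, 8, 16)
seHidden(seRatio = 4)   # returns c(4, 8, 16, 32)
```

With the defaults, the first encoder block's module would be squeezed to only two channels. This is worth keeping in mind before raising `seRatio`.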

Value

UNet model instance as a torch nn_module.

Details

Define a UNet architecture with 4 blocks in the encoder, a bottleneck block, and 4 blocks in the decoder. UNet can accept a variable number of input channels, and the user can define the number of feature maps produced in each encoder and decoder block and the bottleneck. Users can also choose to (1) replace all ReLU activation functions with leaky ReLU or swish, (2) implement attention gates along the skip connections, (3) implement squeeze and excitation modules within the encoder blocks, (4) add residual connections within all blocks, (5) replace the bottleneck with a modified atrous spatial pyramid pooling (ASPP) module, and/or (6) implement deep supervision using predictions generated at each stage in the decoder.
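In a standard four-level UNet, the resolution is halved before encoder blocks 2 to 4 and again before the bottleneck. Under that assumption, which is not stated on this page and should be checked against the package source, input height and width should be divisible by 2^4 = 16. The hypothetical helper below (not part of the package) reports the spatial size at each stage.

```r
# Hypothetical helper (not part of the package).
# ASSUMPTION: four 2x downsampling steps (before encoder blocks 2-4 and
# before the bottleneck), as in a standard four-level UNet.
unetStageSizes <- function(size, nDown = 4) {
  if (size %% 2^nDown != 0) {
    stop("size must be divisible by ", 2^nDown)
  }
  sizes <- size / 2^(0:nDown)
  names(sizes) <- c(paste0("encoder", 1:nDown), "bottleneck")
  sizes
}

unetStageSizes(128)  # encoder1 = 128, encoder2 = 64, encoder3 = 32,
                     # encoder4 = 16, bottleneck = 8
```

For example, `unetStageSizes(100)` stops with an error, because 100 is not divisible by 16.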

Examples

library(torch)
library(geodl)

# Generate example data as a torch tensor:
# 12 images, 4 channels (bands), 128 x 128 pixels
tensorIn <- torch::torch_rand(c(12, 4, 128, 128))

# Instantiate the model
model <- defineUNet(inChn = 4,
                    nCls = 3,
                    actFunc = "lrelu",
                    useAttn = TRUE,
                    useSE = TRUE,
                    useRes = TRUE,
                    useASPP = TRUE,
                    useDS = TRUE,
                    enChn = c(16, 32, 64, 128),
                    dcChn = c(128, 64, 32, 16),
                    btnChn = 256,
                    dilRates = c(1, 2, 4, 8, 16),
                    dilChn = c(16, 16, 16, 16, 16),
                    negative_slope = 0.01,
                    seRatio = 8)

# Predict data with the model; because useDS = TRUE,
# pred is a list of four predictions rather than a single tensor
pred <- model(tensorIn)