UNet Encoders
Introduction
The backbone or encoder component of semantic segmentation models serves the same purpose as the convolutional layers in a CNN architecture designed for scene labeling tasks: it characterizes spatial patterns at varying scales. As a result, the encoder component of semantic segmentation architectures, such as UNets, can be augmented to use a variety of pre-defined CNN architectures, such as ResNets and VGGNets. Since these common architectures have been trained on large datasets, such as ImageNet, pre-trained weights can also be incorporated into the backbone or encoder component of semantic segmentation models. This component of the network can then either be frozen or updated during the training process. Such a use of transfer learning may allow models to be trained with less training data and/or for fewer epochs while still obtaining adequate results.
In this short module, I will demonstrate augmenting UNet with pre-defined backbones. The first example will use a VGGNet-16 architecture as the model backbone, and the second example will use a ResNet backbone. I will only define the model architectures and summarize them using torchinfo; I will not train them. However, the methods used to train the UNet in the Train a UNet module could be applied to these architectures if desired.
Since I am defining architectures, I need to import torch and torch.nn. The torchinfo package is used to summarize the model architecture while the backbones will be accessed using the torchvision implementation. Lastly, I define the GPU as the device.
import torch
import torch.nn as nn
from torchinfo import summary
import torchvision.models
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
cuda:0
VGGNet-16 Encoder
I begin by creating a VGGNet backbone. The first step is to instantiate the backbone, as implemented in the torchvision.models subpackage. I use a modified version of VGGNet-16 that incorporates batch normalization using the vgg16_bn() function. Since I will not actually train the algorithm, I do not download the pre-trained weights.
vgg16 = torchvision.models.vgg16_bn(pretrained=False).to(device)
C:\Users\vidcg\ANACON~1\envs\torchENV\lib\site-packages\torchvision\models\_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
warnings.warn(
C:\Users\vidcg\ANACON~1\envs\torchENV\lib\site-packages\torchvision\models\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.
warnings.warn(msg)
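As the warnings note, the pretrained argument is deprecated in favor of the weights argument as of torchvision 0.13. A minimal sketch of the equivalent calls using the newer API (the code in this module retains the older syntax, which still works but raises the warnings shown):

# Random initialization (equivalent to pretrained=False)
vgg16 = torchvision.models.vgg16_bn(weights=None).to(device)
# ImageNet initialization (equivalent to pretrained=True)
# vgg16 = torchvision.models.vgg16_bn(weights=torchvision.models.VGG16_BN_Weights.IMAGENET1K_V1).to(device)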
I next obtain a summary of the VGGNet-16 architecture using an input array size of (12, 3, 256, 256). In order to use this architecture within UNet, I will need to be able to obtain the outputs before each max pooling operation so that the results can be passed to the decoder via the skip connections. The array sizes in the spatial dimensions being passed through the skip connections must match those of the feature maps to which they are concatenated.
summary(vgg16, (12,3,256,256))
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
VGG [12, 1000] --
├─Sequential: 1-1 [12, 512, 8, 8] --
│ └─Conv2d: 2-1 [12, 64, 256, 256] 1,792
│ └─BatchNorm2d: 2-2 [12, 64, 256, 256] 128
│ └─ReLU: 2-3 [12, 64, 256, 256] --
│ └─Conv2d: 2-4 [12, 64, 256, 256] 36,928
│ └─BatchNorm2d: 2-5 [12, 64, 256, 256] 128
│ └─ReLU: 2-6 [12, 64, 256, 256] --
│ └─MaxPool2d: 2-7 [12, 64, 128, 128] --
│ └─Conv2d: 2-8 [12, 128, 128, 128] 73,856
│ └─BatchNorm2d: 2-9 [12, 128, 128, 128] 256
│ └─ReLU: 2-10 [12, 128, 128, 128] --
│ └─Conv2d: 2-11 [12, 128, 128, 128] 147,584
│ └─BatchNorm2d: 2-12 [12, 128, 128, 128] 256
│ └─ReLU: 2-13 [12, 128, 128, 128] --
│ └─MaxPool2d: 2-14 [12, 128, 64, 64] --
│ └─Conv2d: 2-15 [12, 256, 64, 64] 295,168
│ └─BatchNorm2d: 2-16 [12, 256, 64, 64] 512
│ └─ReLU: 2-17 [12, 256, 64, 64] --
│ └─Conv2d: 2-18 [12, 256, 64, 64] 590,080
│ └─BatchNorm2d: 2-19 [12, 256, 64, 64] 512
│ └─ReLU: 2-20 [12, 256, 64, 64] --
│ └─Conv2d: 2-21 [12, 256, 64, 64] 590,080
│ └─BatchNorm2d: 2-22 [12, 256, 64, 64] 512
│ └─ReLU: 2-23 [12, 256, 64, 64] --
│ └─MaxPool2d: 2-24 [12, 256, 32, 32] --
│ └─Conv2d: 2-25 [12, 512, 32, 32] 1,180,160
│ └─BatchNorm2d: 2-26 [12, 512, 32, 32] 1,024
│ └─ReLU: 2-27 [12, 512, 32, 32] --
│ └─Conv2d: 2-28 [12, 512, 32, 32] 2,359,808
│ └─BatchNorm2d: 2-29 [12, 512, 32, 32] 1,024
│ └─ReLU: 2-30 [12, 512, 32, 32] --
│ └─Conv2d: 2-31 [12, 512, 32, 32] 2,359,808
│ └─BatchNorm2d: 2-32 [12, 512, 32, 32] 1,024
│ └─ReLU: 2-33 [12, 512, 32, 32] --
│ └─MaxPool2d: 2-34 [12, 512, 16, 16] --
│ └─Conv2d: 2-35 [12, 512, 16, 16] 2,359,808
│ └─BatchNorm2d: 2-36 [12, 512, 16, 16] 1,024
│ └─ReLU: 2-37 [12, 512, 16, 16] --
│ └─Conv2d: 2-38 [12, 512, 16, 16] 2,359,808
│ └─BatchNorm2d: 2-39 [12, 512, 16, 16] 1,024
│ └─ReLU: 2-40 [12, 512, 16, 16] --
│ └─Conv2d: 2-41 [12, 512, 16, 16] 2,359,808
│ └─BatchNorm2d: 2-42 [12, 512, 16, 16] 1,024
│ └─ReLU: 2-43 [12, 512, 16, 16] --
│ └─MaxPool2d: 2-44 [12, 512, 8, 8] --
├─AdaptiveAvgPool2d: 1-2 [12, 512, 7, 7] --
├─Sequential: 1-3 [12, 1000] --
│ └─Linear: 2-45 [12, 4096] 102,764,544
│ └─ReLU: 2-46 [12, 4096] --
│ └─Dropout: 2-47 [12, 4096] --
│ └─Linear: 2-48 [12, 4096] 16,781,312
│ └─ReLU: 2-49 [12, 4096] --
│ └─Dropout: 2-50 [12, 4096] --
│ └─Linear: 2-51 [12, 1000] 4,097,000
==========================================================================================
Total params: 138,365,992
Trainable params: 138,365,992
Non-trainable params: 0
Total mult-adds (G): 242.23
==========================================================================================
Input size (MB): 9.44
Forward/backward pass size (MB): 3398.27
Params size (MB): 553.46
Estimated Total Size (MB): 3961.17
==========================================================================================
It is possible to obtain a list of all of the architecture’s components using the features property of the instantiated VGGNet-16 architecture. Printing the result, you can see that the features block has 44 components, indexed 0 through 43. I will need to divide the model’s components as follows to extract the results at the correct locations in the architecture.
- Layers 0-5 -> 1st skip connection (original array size)
- Layers 6-12 -> 2nd skip connection (original array size/2)
- Layers 13-22 -> 3rd skip connection (original array size/4)
- Layers 23-32 -> 4th skip connection (original array size/8)
- Layers 33-42 -> Bottleneck (original array size/16)
I will need to split the architecture before each max pooling operation so that I can match each encoder step with the associated decoder step that has the same array sizes in the spatial dimensions. The last max pooling layer is not used since I do not want to decrease the array size further. Instead, the data will enter the decoder component and the first 2D transpose convolution operation.
vgg16F = torchvision.models.vgg16_bn(pretrained=False).features
vgg16F
Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
(6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(9): ReLU(inplace=True)
(10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(12): ReLU(inplace=True)
(13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(14): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(16): ReLU(inplace=True)
(17): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(18): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(19): ReLU(inplace=True)
(20): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(21): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(22): ReLU(inplace=True)
(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(24): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(25): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(26): ReLU(inplace=True)
(27): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(28): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(29): ReLU(inplace=True)
(30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(31): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(32): ReLU(inplace=True)
(33): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(35): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(36): ReLU(inplace=True)
(37): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(38): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(39): ReLU(inplace=True)
(40): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(41): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(42): ReLU(inplace=True)
(43): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
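As a quick check on the split points listed above, the following sketch passes a dummy tensor through each proposed block of vgg16F and prints the resulting shapes; the spatial size should halve after every block except the first:

x = torch.rand(1, 3, 256, 256)
for start, end in [(0, 6), (6, 13), (13, 23), (23, 33), (33, 43)]:
    x = nn.Sequential(*vgg16F[start:end])(x)
    print(x.shape)
# Expected: [1, 64, 256, 256], [1, 128, 128, 128], [1, 256, 64, 64],
# [1, 512, 32, 32], and [1, 512, 16, 16]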
With what we learned about the VGGNet-16 architecture from our explorations above, let’s now incorporate it into a UNet architecture. As I did in the UNet Architecture module, I define double_conv() and up_conv() functions for use in the UNet architecture. I next define the UNet architecture by subclassing nn.Module. Here are the key components.
- The __init__() constructor method defines the following parameters: the number of input channels (inChn), the number of classes being differentiated (nCls), and whether or not to use pre-trained weights (useWghts).
- The features from the VGGNet-16 architecture are extracted using the vgg16_bn() function from torchvision and the features property.
- A list of output sizes for the decoder are defined as a list.
- The components of the encoder are defined as the appropriate subset of layers from VGGNet-16. The components are combined using nn.Sequential(), and * is used to unpack the list of features. Again, I am breaking the model apart before each max pooling operation so that the array sizes in the spatial dimensions match between each encoder block and associated decoder block.
- I define the bottleneck as the last set of components of the VGGNet-16 model. The last feature is not used since we do not want to apply the last max pooling operation.
- I define the decoder blocks, which each consist of upsampling using 2D transpose convolution and a series of two 2D convolution layers to learn filters.
- The last layer uses 2D convolution with a kernel size of 1x1 and a stride of 1. The number of outputs is equal to the number of classes being differentiated.
- The forward() method defines how data will pass through the architecture. Data first pass through each of the encoder blocks followed by the bottleneck. In each decoder block, upsampling is performed using 2D transpose convolution, the feature maps from the associated encoder block are concatenated, and the features pass through two 2D convolution layers. Lastly, the data pass through the final 2D convolution layer, which has a kernel size of 1x1 and a stride of 1. The output will be logits for each class. In the case of a binary classification, it could return only a logit for the positive case. Probabilities are not returned since a sigmoid or softmax activation is not being applied.
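Since raw logits are returned, a loss function that applies the activation internally can consume the output directly. A minimal sketch with random tensors standing in for the model output and labels:

# nn.CrossEntropyLoss() applies log-softmax internally, so it expects raw logits
criterion = nn.CrossEntropyLoss()
logits = torch.randn(12, 10, 256, 256)            # (batch, nCls, height, width)
targets = torch.randint(0, 10, (12, 256, 256))    # integer class codes per pixel
loss = criterion(logits, targets)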
def double_conv(inChannels, outChannels):
    return nn.Sequential(
        nn.Conv2d(inChannels, outChannels, kernel_size=(3,3), stride=1, padding=1),
        nn.BatchNorm2d(outChannels),
        nn.ReLU(inplace=True),
        nn.Conv2d(outChannels, outChannels, kernel_size=(3,3), stride=1, padding=1),
        nn.BatchNorm2d(outChannels),
        nn.ReLU(inplace=True)
    )

def up_conv(inChannels, outChannels):
    return nn.Sequential(
        nn.ConvTranspose2d(inChannels, outChannels, kernel_size=(2,2), stride=2),
        nn.BatchNorm2d(outChannels),
        nn.ReLU(inplace=True)
    )
class myUNetVGG16(nn.Module):
    def __init__(self, inChn, nCls, useWghts=True):
        super().__init__()
        self.inChn = inChn
        self.nCls = nCls
        self.useWghts = useWghts
        self.base_model = torchvision.models.vgg16_bn(pretrained=useWghts).features
        self.outSizes = [64, 128, 256, 512, 512]
        self.encoder1 = nn.Sequential(*self.base_model[:6])
        self.encoder2 = nn.Sequential(*self.base_model[6:13])
        self.encoder3 = nn.Sequential(*self.base_model[13:23])
        self.encoder4 = nn.Sequential(*self.base_model[23:33])
        self.bottleneck = nn.Sequential(*self.base_model[33:43])
        self.decoder1up = up_conv(self.outSizes[4], 512)
        self.decoder1 = double_conv(self.outSizes[3] + 512, 256)
        self.decoder2up = up_conv(256, 256)
        self.decoder2 = double_conv(self.outSizes[2] + 256, 128)
        self.decoder3up = up_conv(128, 128)
        self.decoder3 = double_conv(self.outSizes[1] + 128, 64)
        self.decoder4up = up_conv(64, 64)
        self.decoder4 = double_conv(self.outSizes[0] + 64, 32)
        self.classifier = nn.Conv2d(32, nCls, kernel_size=(1,1))

    def forward(self, x):
        #Encoder
        encoder1 = self.encoder1(x)
        encoder2 = self.encoder2(encoder1)
        encoder3 = self.encoder3(encoder2)
        encoder4 = self.encoder4(encoder3)

        #Bottleneck
        x = self.bottleneck(encoder4)

        #Decoder
        x = self.decoder1up(x)
        x = torch.concat([x, encoder4], dim=1)
        x = self.decoder1(x)

        x = self.decoder2up(x)
        x = torch.concat([x, encoder3], dim=1)
        x = self.decoder2(x)

        x = self.decoder3up(x)
        x = torch.concat([x, encoder2], dim=1)
        x = self.decoder3(x)

        x = self.decoder4up(x)
        x = torch.concat([x, encoder1], dim=1)
        x = self.decoder4(x)

        #Classifier head
        x = self.classifier(x)

        return x
I instantiate an instance of the myUNetVGG16() architecture that accepts 3 input channels and outputs 10 class logits. I also initialize the model using the VGGNet-16 pre-trained weights available from torchvision.
Using the summary() function from torchinfo, you can see that the model has just over 20 million trainable parameters. There are currently no non-trainable parameters. This is because, even though I downloaded the pre-trained weights, all parameters can still be updated during the learning process. In other words, the model will be initialized using these weights as opposed to random weights, but these layers and associated parameters are still trainable.
model = myUNetVGG16(inChn=3, nCls=10, useWghts=True).to(device)
C:\Users\vidcg\ANACON~1\envs\torchENV\lib\site-packages\torchvision\models\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG16_BN_Weights.IMAGENET1K_V1`. You can also use `weights=VGG16_BN_Weights.DEFAULT` to get the most up-to-date weights.
warnings.warn(msg)
summary(model, (12,3,256,256))
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
myUNetVGG16 [12, 10, 256, 256] --
├─Sequential: 1-1 [12, 64, 256, 256] --
│ └─Conv2d: 2-1 [12, 64, 256, 256] 1,792
│ └─BatchNorm2d: 2-2 [12, 64, 256, 256] 128
│ └─ReLU: 2-3 [12, 64, 256, 256] --
│ └─Conv2d: 2-4 [12, 64, 256, 256] 36,928
│ └─BatchNorm2d: 2-5 [12, 64, 256, 256] 128
│ └─ReLU: 2-6 [12, 64, 256, 256] --
├─Sequential: 1-2 [12, 128, 128, 128] --
│ └─MaxPool2d: 2-7 [12, 64, 128, 128] --
│ └─Conv2d: 2-8 [12, 128, 128, 128] 73,856
│ └─BatchNorm2d: 2-9 [12, 128, 128, 128] 256
│ └─ReLU: 2-10 [12, 128, 128, 128] --
│ └─Conv2d: 2-11 [12, 128, 128, 128] 147,584
│ └─BatchNorm2d: 2-12 [12, 128, 128, 128] 256
│ └─ReLU: 2-13 [12, 128, 128, 128] --
├─Sequential: 1-3 [12, 256, 64, 64] --
│ └─MaxPool2d: 2-14 [12, 128, 64, 64] --
│ └─Conv2d: 2-15 [12, 256, 64, 64] 295,168
│ └─BatchNorm2d: 2-16 [12, 256, 64, 64] 512
│ └─ReLU: 2-17 [12, 256, 64, 64] --
│ └─Conv2d: 2-18 [12, 256, 64, 64] 590,080
│ └─BatchNorm2d: 2-19 [12, 256, 64, 64] 512
│ └─ReLU: 2-20 [12, 256, 64, 64] --
│ └─Conv2d: 2-21 [12, 256, 64, 64] 590,080
│ └─BatchNorm2d: 2-22 [12, 256, 64, 64] 512
│ └─ReLU: 2-23 [12, 256, 64, 64] --
├─Sequential: 1-4 [12, 512, 32, 32] --
│ └─MaxPool2d: 2-24 [12, 256, 32, 32] --
│ └─Conv2d: 2-25 [12, 512, 32, 32] 1,180,160
│ └─BatchNorm2d: 2-26 [12, 512, 32, 32] 1,024
│ └─ReLU: 2-27 [12, 512, 32, 32] --
│ └─Conv2d: 2-28 [12, 512, 32, 32] 2,359,808
│ └─BatchNorm2d: 2-29 [12, 512, 32, 32] 1,024
│ └─ReLU: 2-30 [12, 512, 32, 32] --
│ └─Conv2d: 2-31 [12, 512, 32, 32] 2,359,808
│ └─BatchNorm2d: 2-32 [12, 512, 32, 32] 1,024
│ └─ReLU: 2-33 [12, 512, 32, 32] --
├─Sequential: 1-5 [12, 512, 16, 16] --
│ └─MaxPool2d: 2-34 [12, 512, 16, 16] --
│ └─Conv2d: 2-35 [12, 512, 16, 16] 2,359,808
│ └─BatchNorm2d: 2-36 [12, 512, 16, 16] 1,024
│ └─ReLU: 2-37 [12, 512, 16, 16] --
│ └─Conv2d: 2-38 [12, 512, 16, 16] 2,359,808
│ └─BatchNorm2d: 2-39 [12, 512, 16, 16] 1,024
│ └─ReLU: 2-40 [12, 512, 16, 16] --
│ └─Conv2d: 2-41 [12, 512, 16, 16] 2,359,808
│ └─BatchNorm2d: 2-42 [12, 512, 16, 16] 1,024
│ └─ReLU: 2-43 [12, 512, 16, 16] --
├─Sequential: 1-6 [12, 512, 32, 32] --
│ └─ConvTranspose2d: 2-44 [12, 512, 32, 32] 1,049,088
│ └─BatchNorm2d: 2-45 [12, 512, 32, 32] 1,024
│ └─ReLU: 2-46 [12, 512, 32, 32] --
├─Sequential: 1-7 [12, 256, 32, 32] --
│ └─Conv2d: 2-47 [12, 256, 32, 32] 2,359,552
│ └─BatchNorm2d: 2-48 [12, 256, 32, 32] 512
│ └─ReLU: 2-49 [12, 256, 32, 32] --
│ └─Conv2d: 2-50 [12, 256, 32, 32] 590,080
│ └─BatchNorm2d: 2-51 [12, 256, 32, 32] 512
│ └─ReLU: 2-52 [12, 256, 32, 32] --
├─Sequential: 1-8 [12, 256, 64, 64] --
│ └─ConvTranspose2d: 2-53 [12, 256, 64, 64] 262,400
│ └─BatchNorm2d: 2-54 [12, 256, 64, 64] 512
│ └─ReLU: 2-55 [12, 256, 64, 64] --
├─Sequential: 1-9 [12, 128, 64, 64] --
│ └─Conv2d: 2-56 [12, 128, 64, 64] 589,952
│ └─BatchNorm2d: 2-57 [12, 128, 64, 64] 256
│ └─ReLU: 2-58 [12, 128, 64, 64] --
│ └─Conv2d: 2-59 [12, 128, 64, 64] 147,584
│ └─BatchNorm2d: 2-60 [12, 128, 64, 64] 256
│ └─ReLU: 2-61 [12, 128, 64, 64] --
├─Sequential: 1-10 [12, 128, 128, 128] --
│ └─ConvTranspose2d: 2-62 [12, 128, 128, 128] 65,664
│ └─BatchNorm2d: 2-63 [12, 128, 128, 128] 256
│ └─ReLU: 2-64 [12, 128, 128, 128] --
├─Sequential: 1-11 [12, 64, 128, 128] --
│ └─Conv2d: 2-65 [12, 64, 128, 128] 147,520
│ └─BatchNorm2d: 2-66 [12, 64, 128, 128] 128
│ └─ReLU: 2-67 [12, 64, 128, 128] --
│ └─Conv2d: 2-68 [12, 64, 128, 128] 36,928
│ └─BatchNorm2d: 2-69 [12, 64, 128, 128] 128
│ └─ReLU: 2-70 [12, 64, 128, 128] --
├─Sequential: 1-12 [12, 64, 256, 256] --
│ └─ConvTranspose2d: 2-71 [12, 64, 256, 256] 16,448
│ └─BatchNorm2d: 2-72 [12, 64, 256, 256] 128
│ └─ReLU: 2-73 [12, 64, 256, 256] --
├─Sequential: 1-13 [12, 32, 256, 256] --
│ └─Conv2d: 2-74 [12, 32, 256, 256] 36,896
│ └─BatchNorm2d: 2-75 [12, 32, 256, 256] 64
│ └─ReLU: 2-76 [12, 32, 256, 256] --
│ └─Conv2d: 2-77 [12, 32, 256, 256] 9,248
│ └─BatchNorm2d: 2-78 [12, 32, 256, 256] 64
│ └─ReLU: 2-79 [12, 32, 256, 256] --
├─Conv2d: 1-14 [12, 10, 256, 256] 330
==========================================================================================
Total params: 20,038,666
Trainable params: 20,038,666
Non-trainable params: 0
Total mult-adds (G): 437.69
==========================================================================================
Input size (MB): 9.44
Forward/backward pass size (MB): 6480.20
Params size (MB): 80.15
Estimated Total Size (MB): 6569.79
==========================================================================================
I next demonstrate using for loops to freeze the trainable parameters in the backbone or encoder layers. This is accomplished by iterating over all layers in the list of layers in the VGGNet-16 model followed by iterating over all of the parameters in each of these layers to set the requires_grad property to False.
If I print the summary again, you can now see that only a subset of the total parameters is trainable. In other words, the parameters in the encoder component of the model that were defined using the VGGNet-16 model can no longer be updated. Only parameters in the decoder component will be trainable.
for l in model.base_model:
    for param in l.parameters():
        param.requires_grad = False
summary(model, (12,3,256,256))
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
myUNetVGG16 [12, 10, 256, 256] --
├─Sequential: 1-1 [12, 64, 256, 256] --
│ └─Conv2d: 2-1 [12, 64, 256, 256] (1,792)
│ └─BatchNorm2d: 2-2 [12, 64, 256, 256] (128)
│ └─ReLU: 2-3 [12, 64, 256, 256] --
│ └─Conv2d: 2-4 [12, 64, 256, 256] (36,928)
│ └─BatchNorm2d: 2-5 [12, 64, 256, 256] (128)
│ └─ReLU: 2-6 [12, 64, 256, 256] --
├─Sequential: 1-2 [12, 128, 128, 128] --
│ └─MaxPool2d: 2-7 [12, 64, 128, 128] --
│ └─Conv2d: 2-8 [12, 128, 128, 128] (73,856)
│ └─BatchNorm2d: 2-9 [12, 128, 128, 128] (256)
│ └─ReLU: 2-10 [12, 128, 128, 128] --
│ └─Conv2d: 2-11 [12, 128, 128, 128] (147,584)
│ └─BatchNorm2d: 2-12 [12, 128, 128, 128] (256)
│ └─ReLU: 2-13 [12, 128, 128, 128] --
├─Sequential: 1-3 [12, 256, 64, 64] --
│ └─MaxPool2d: 2-14 [12, 128, 64, 64] --
│ └─Conv2d: 2-15 [12, 256, 64, 64] (295,168)
│ └─BatchNorm2d: 2-16 [12, 256, 64, 64] (512)
│ └─ReLU: 2-17 [12, 256, 64, 64] --
│ └─Conv2d: 2-18 [12, 256, 64, 64] (590,080)
│ └─BatchNorm2d: 2-19 [12, 256, 64, 64] (512)
│ └─ReLU: 2-20 [12, 256, 64, 64] --
│ └─Conv2d: 2-21 [12, 256, 64, 64] (590,080)
│ └─BatchNorm2d: 2-22 [12, 256, 64, 64] (512)
│ └─ReLU: 2-23 [12, 256, 64, 64] --
├─Sequential: 1-4 [12, 512, 32, 32] --
│ └─MaxPool2d: 2-24 [12, 256, 32, 32] --
│ └─Conv2d: 2-25 [12, 512, 32, 32] (1,180,160)
│ └─BatchNorm2d: 2-26 [12, 512, 32, 32] (1,024)
│ └─ReLU: 2-27 [12, 512, 32, 32] --
│ └─Conv2d: 2-28 [12, 512, 32, 32] (2,359,808)
│ └─BatchNorm2d: 2-29 [12, 512, 32, 32] (1,024)
│ └─ReLU: 2-30 [12, 512, 32, 32] --
│ └─Conv2d: 2-31 [12, 512, 32, 32] (2,359,808)
│ └─BatchNorm2d: 2-32 [12, 512, 32, 32] (1,024)
│ └─ReLU: 2-33 [12, 512, 32, 32] --
├─Sequential: 1-5 [12, 512, 16, 16] --
│ └─MaxPool2d: 2-34 [12, 512, 16, 16] --
│ └─Conv2d: 2-35 [12, 512, 16, 16] (2,359,808)
│ └─BatchNorm2d: 2-36 [12, 512, 16, 16] (1,024)
│ └─ReLU: 2-37 [12, 512, 16, 16] --
│ └─Conv2d: 2-38 [12, 512, 16, 16] (2,359,808)
│ └─BatchNorm2d: 2-39 [12, 512, 16, 16] (1,024)
│ └─ReLU: 2-40 [12, 512, 16, 16] --
│ └─Conv2d: 2-41 [12, 512, 16, 16] (2,359,808)
│ └─BatchNorm2d: 2-42 [12, 512, 16, 16] (1,024)
│ └─ReLU: 2-43 [12, 512, 16, 16] --
├─Sequential: 1-6 [12, 512, 32, 32] --
│ └─ConvTranspose2d: 2-44 [12, 512, 32, 32] 1,049,088
│ └─BatchNorm2d: 2-45 [12, 512, 32, 32] 1,024
│ └─ReLU: 2-46 [12, 512, 32, 32] --
├─Sequential: 1-7 [12, 256, 32, 32] --
│ └─Conv2d: 2-47 [12, 256, 32, 32] 2,359,552
│ └─BatchNorm2d: 2-48 [12, 256, 32, 32] 512
│ └─ReLU: 2-49 [12, 256, 32, 32] --
│ └─Conv2d: 2-50 [12, 256, 32, 32] 590,080
│ └─BatchNorm2d: 2-51 [12, 256, 32, 32] 512
│ └─ReLU: 2-52 [12, 256, 32, 32] --
├─Sequential: 1-8 [12, 256, 64, 64] --
│ └─ConvTranspose2d: 2-53 [12, 256, 64, 64] 262,400
│ └─BatchNorm2d: 2-54 [12, 256, 64, 64] 512
│ └─ReLU: 2-55 [12, 256, 64, 64] --
├─Sequential: 1-9 [12, 128, 64, 64] --
│ └─Conv2d: 2-56 [12, 128, 64, 64] 589,952
│ └─BatchNorm2d: 2-57 [12, 128, 64, 64] 256
│ └─ReLU: 2-58 [12, 128, 64, 64] --
│ └─Conv2d: 2-59 [12, 128, 64, 64] 147,584
│ └─BatchNorm2d: 2-60 [12, 128, 64, 64] 256
│ └─ReLU: 2-61 [12, 128, 64, 64] --
├─Sequential: 1-10 [12, 128, 128, 128] --
│ └─ConvTranspose2d: 2-62 [12, 128, 128, 128] 65,664
│ └─BatchNorm2d: 2-63 [12, 128, 128, 128] 256
│ └─ReLU: 2-64 [12, 128, 128, 128] --
├─Sequential: 1-11 [12, 64, 128, 128] --
│ └─Conv2d: 2-65 [12, 64, 128, 128] 147,520
│ └─BatchNorm2d: 2-66 [12, 64, 128, 128] 128
│ └─ReLU: 2-67 [12, 64, 128, 128] --
│ └─Conv2d: 2-68 [12, 64, 128, 128] 36,928
│ └─BatchNorm2d: 2-69 [12, 64, 128, 128] 128
│ └─ReLU: 2-70 [12, 64, 128, 128] --
├─Sequential: 1-12 [12, 64, 256, 256] --
│ └─ConvTranspose2d: 2-71 [12, 64, 256, 256] 16,448
│ └─BatchNorm2d: 2-72 [12, 64, 256, 256] 128
│ └─ReLU: 2-73 [12, 64, 256, 256] --
├─Sequential: 1-13 [12, 32, 256, 256] --
│ └─Conv2d: 2-74 [12, 32, 256, 256] 36,896
│ └─BatchNorm2d: 2-75 [12, 32, 256, 256] 64
│ └─ReLU: 2-76 [12, 32, 256, 256] --
│ └─Conv2d: 2-77 [12, 32, 256, 256] 9,248
│ └─BatchNorm2d: 2-78 [12, 32, 256, 256] 64
│ └─ReLU: 2-79 [12, 32, 256, 256] --
├─Conv2d: 1-14 [12, 10, 256, 256] 330
==========================================================================================
Total params: 20,038,666
Trainable params: 5,315,530
Non-trainable params: 14,723,136
Total mult-adds (G): 437.69
==========================================================================================
Input size (MB): 9.44
Forward/backward pass size (MB): 6480.20
Params size (MB): 80.15
Estimated Total Size (MB): 6569.79
==========================================================================================
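The frozen and trainable counts reported by torchinfo can also be verified directly by summing parameter sizes; and, as a one-line alternative to the nested for loops above, nn.Module provides the in-place requires_grad_() method. A minimal sketch, assuming the model defined above:

# One-line alternative to the nested loops: freeze the entire backbone
model.base_model.requires_grad_(False)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(trainable, frozen)
# Expected: 5315530 trainable and 14723136 frozen, matching the summary above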
ResNet Encoder
As a second example, I will now define a UNet architecture that can accept different versions of a ResNet architecture as the backbone. I begin by redefining the double_conv() and up_conv() functions. In the myUNetResNet() subclass definition, I define the following parameters:
- inChn = number of input channels
- nCls = number of output classes
- resNet = which ResNet architecture to use (“18”, “34”, or “50”); the default is “18”
- useWghts = whether or not to initialize using pre-trained weights
The backbone used will depend on the argument provided to the resNet parameter. Note that this will not change how the encoder blocks are defined since the list of layers does not change. The type of ResNet varies based on the number of operations in each layer, but the total number of layers remains the same. Similar to the UNet using the VGGNet-16 backbone, the layers from the ResNet must be partitioned into the appropriate encoder blocks such that the correct array size is extracted to concatenate with the associated decoder block. The first operation in the ResNet architecture actually reduces the array size, so the first encoder block uses the defined double_conv() operation as opposed to the first set of layers in the ResNet. Portions of the ResNet are used to define the operations in the next 4 encoder blocks and the bottleneck layer. Once these components are defined, the decoder components are defined, which consist of upsampling with 2D transpose convolution and learning new filters with a sequence of two 2D convolution layers.
The forward method defines, as usual, how the data pass through the architecture. The input data first pass through the first encoder block. Since this block is not part of the ResNet architecture, the original data, as opposed to the output of the first encoder block, are also passed through the second encoder block. The 2nd through 5th encoder blocks and the bottleneck block are defined based on components of the ResNet architecture.
Next, the data pass through the decoder component of the architecture. Each block consists of upsampling with 2D transpose convolution, concatenation of the feature maps from the associated encoder block, and two 2D convolution layers to learn additional filters. Lastly, the data are passed through a 2D convolution layer with a kernel size of 1x1 and a stride of 1 to obtain the class logits.
Again, it is important here that the correct stage of the ResNet architecture be assigned to the correct encoder block so that the skip connections deliver arrays with the correct sizes in the spatial dimensions, which are then concatenated with the output of the bottleneck or prior decoder block.
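To see which portions of the ResNet feed which encoder blocks, the following sketch lists the top-level children of a ResNet-18; the indices shown are what the class definition below relies on:

resnet18 = torchvision.models.resnet18(pretrained=False)
for i, layer in enumerate(resnet18.children()):
    print(i, layer.__class__.__name__)
# 0 Conv2d (7x7, stride 2; halves the spatial size, hence the separate first encoder block)
# 1 BatchNorm2d
# 2 ReLU                -> indices 0-2 form the 2nd encoder block
# 3 MaxPool2d
# 4 Sequential (layer1) -> indices 3-4 form the 3rd encoder block
# 5 Sequential (layer2) -> 4th encoder block
# 6 Sequential (layer3) -> 5th encoder block
# 7 Sequential (layer4) -> bottleneck
# 8 AdaptiveAvgPool2d   -> not used
# 9 Linear              -> classification head, not used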
def double_conv(inChannels, outChannels):
    return nn.Sequential(
        nn.Conv2d(inChannels, outChannels, kernel_size=(3,3), stride=1, padding=1),
        nn.BatchNorm2d(outChannels),
        nn.ReLU(inplace=True),
        nn.Conv2d(outChannels, outChannels, kernel_size=(3,3), stride=1, padding=1),
        nn.BatchNorm2d(outChannels),
        nn.ReLU(inplace=True)
    )

def up_conv(inChannels, outChannels):
    return nn.Sequential(
        nn.ConvTranspose2d(inChannels, outChannels, kernel_size=(2,2), stride=2),
        nn.BatchNorm2d(outChannels),
        nn.ReLU(inplace=True)
    )
class myUNetResNet(nn.Module):
    def __init__(self, inChn, nCls, resNet="18", useWghts=True):
        super().__init__()
        self.inChn = inChn
        self.nCls = nCls
        self.resNet = resNet
        self.useWghts = useWghts
        if(resNet == "34"):
            self.base_model = torchvision.models.resnet34(pretrained=useWghts)
            self.base_layers = list(self.base_model.children())
            self.outSizes = [64, 64, 128, 256, 512]
        elif(resNet == "50"):
            self.base_model = torchvision.models.resnet50(pretrained=useWghts)
            self.base_layers = list(self.base_model.children())
            self.outSizes = [64, 256, 512, 1024, 2048]
        else:
            self.base_model = torchvision.models.resnet18(pretrained=useWghts)
            self.base_layers = list(self.base_model.children())
            self.outSizes = [64, 64, 128, 256, 512]
        self.encoder1 = double_conv(inChn, 16)
        self.encoder2 = nn.Sequential(*self.base_layers[:3])
        self.encoder3 = nn.Sequential(*self.base_layers[3:5])
        self.encoder4 = self.base_layers[5]
        self.encoder5 = self.base_layers[6]
        self.bottleneck = self.base_layers[7]
        self.decoder1up = up_conv(self.outSizes[4], 512)
        self.decoder1 = double_conv(self.outSizes[3] + 512, 256)
        self.decoder2up = up_conv(256, 256)
        self.decoder2 = double_conv(self.outSizes[2] + 256, 128)
        self.decoder3up = up_conv(128, 128)
        self.decoder3 = double_conv(self.outSizes[1] + 128, 64)
        self.decoder4up = up_conv(64, 64)
        self.decoder4 = double_conv(self.outSizes[0] + 64, 32)
        self.decoder5up = up_conv(32, 32)
        self.decoder5 = double_conv(16 + 32, 16)
        self.classifier = nn.Conv2d(16, nCls, kernel_size=(1,1))

    def forward(self, x):
        #Encoder
        encoder1 = self.encoder1(x)
        encoder2 = self.encoder2(x)
        encoder3 = self.encoder3(encoder2)
        encoder4 = self.encoder4(encoder3)
        encoder5 = self.encoder5(encoder4)

        #Bottleneck
        x = self.bottleneck(encoder5)

        #Decoder
        x = self.decoder1up(x)
        x = torch.concat([x, encoder5], dim=1)
        x = self.decoder1(x)

        x = self.decoder2up(x)
        x = torch.concat([x, encoder4], dim=1)
        x = self.decoder2(x)

        x = self.decoder3up(x)
        x = torch.concat([x, encoder3], dim=1)
        x = self.decoder3(x)

        x = self.decoder4up(x)
        x = torch.concat([x, encoder2], dim=1)
        x = self.decoder4(x)

        x = self.decoder5up(x)
        x = torch.concat([x, encoder1], dim=1)
        x = self.decoder5(x)

        #Classifier head
        x = self.classifier(x)

        return x
I instantiate an instance of the myUNetResNet() subclass that accepts 3 channels, differentiates 10 classes, uses a ResNet-18 architecture in the backbone or encoder, and is initialized using pre-trained weights in the encoder. I print a summary to explore the model architecture, which has just over 16 million trainable parameters.
model = myUNetResNet(inChn=3, nCls=10, resNet="18", useWghts=True).to(device)
C:\Users\vidcg\ANACON~1\envs\torchENV\lib\site-packages\torchvision\models\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.
warnings.warn(msg)
summary(model, (12,3,256,256))
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
myUNetResNet [12, 10, 256, 256] 513,000
├─Sequential: 1-1 [12, 16, 256, 256] --
│ └─Conv2d: 2-1 [12, 16, 256, 256] 448
│ └─BatchNorm2d: 2-2 [12, 16, 256, 256] 32
│ └─ReLU: 2-3 [12, 16, 256, 256] --
│ └─Conv2d: 2-4 [12, 16, 256, 256] 2,320
│ └─BatchNorm2d: 2-5 [12, 16, 256, 256] 32
│ └─ReLU: 2-6 [12, 16, 256, 256] --
├─Sequential: 1-2 [12, 64, 128, 128] --
│ └─Conv2d: 2-7 [12, 64, 128, 128] 9,408
│ └─BatchNorm2d: 2-8 [12, 64, 128, 128] 128
│ └─ReLU: 2-9 [12, 64, 128, 128] --
├─Sequential: 1-3 [12, 64, 64, 64] --
│ └─MaxPool2d: 2-10 [12, 64, 64, 64] --
│ └─Sequential: 2-11 [12, 64, 64, 64] --
│ │ └─BasicBlock: 3-1 [12, 64, 64, 64] 73,984
│ │ └─BasicBlock: 3-2 [12, 64, 64, 64] 73,984
├─Sequential: 1-4 [12, 128, 32, 32] --
│ └─BasicBlock: 2-12 [12, 128, 32, 32] --
│ │ └─Conv2d: 3-3 [12, 128, 32, 32] 73,728
│ │ └─BatchNorm2d: 3-4 [12, 128, 32, 32] 256
│ │ └─ReLU: 3-5 [12, 128, 32, 32] --
│ │ └─Conv2d: 3-6 [12, 128, 32, 32] 147,456
│ │ └─BatchNorm2d: 3-7 [12, 128, 32, 32] 256
│ │ └─Sequential: 3-8 [12, 128, 32, 32] 8,448
│ │ └─ReLU: 3-9 [12, 128, 32, 32] --
│ └─BasicBlock: 2-13 [12, 128, 32, 32] --
│ │ └─Conv2d: 3-10 [12, 128, 32, 32] 147,456
│ │ └─BatchNorm2d: 3-11 [12, 128, 32, 32] 256
│ │ └─ReLU: 3-12 [12, 128, 32, 32] --
│ │ └─Conv2d: 3-13 [12, 128, 32, 32] 147,456
│ │ └─BatchNorm2d: 3-14 [12, 128, 32, 32] 256
│ │ └─ReLU: 3-15 [12, 128, 32, 32] --
├─Sequential: 1-5 [12, 256, 16, 16] --
│ └─BasicBlock: 2-14 [12, 256, 16, 16] --
│ │ └─Conv2d: 3-16 [12, 256, 16, 16] 294,912
│ │ └─BatchNorm2d: 3-17 [12, 256, 16, 16] 512
│ │ └─ReLU: 3-18 [12, 256, 16, 16] --
│ │ └─Conv2d: 3-19 [12, 256, 16, 16] 589,824
│ │ └─BatchNorm2d: 3-20 [12, 256, 16, 16] 512
│ │ └─Sequential: 3-21 [12, 256, 16, 16] 33,280
│ │ └─ReLU: 3-22 [12, 256, 16, 16] --
│ └─BasicBlock: 2-15 [12, 256, 16, 16] --
│ │ └─Conv2d: 3-23 [12, 256, 16, 16] 589,824
│ │ └─BatchNorm2d: 3-24 [12, 256, 16, 16] 512
│ │ └─ReLU: 3-25 [12, 256, 16, 16] --
│ │ └─Conv2d: 3-26 [12, 256, 16, 16] 589,824
│ │ └─BatchNorm2d: 3-27 [12, 256, 16, 16] 512
│ │ └─ReLU: 3-28 [12, 256, 16, 16] --
├─Sequential: 1-6 [12, 512, 8, 8] --
│ └─BasicBlock: 2-16 [12, 512, 8, 8] --
│ │ └─Conv2d: 3-29 [12, 512, 8, 8] 1,179,648
│ │ └─BatchNorm2d: 3-30 [12, 512, 8, 8] 1,024
│ │ └─ReLU: 3-31 [12, 512, 8, 8] --
│ │ └─Conv2d: 3-32 [12, 512, 8, 8] 2,359,296
│ │ └─BatchNorm2d: 3-33 [12, 512, 8, 8] 1,024
│ │ └─Sequential: 3-34 [12, 512, 8, 8] 132,096
│ │ └─ReLU: 3-35 [12, 512, 8, 8] --
│ └─BasicBlock: 2-17 [12, 512, 8, 8] --
│ │ └─Conv2d: 3-36 [12, 512, 8, 8] 2,359,296
│ │ └─BatchNorm2d: 3-37 [12, 512, 8, 8] 1,024
│ │ └─ReLU: 3-38 [12, 512, 8, 8] --
│ │ └─Conv2d: 3-39 [12, 512, 8, 8] 2,359,296
│ │ └─BatchNorm2d: 3-40 [12, 512, 8, 8] 1,024
│ │ └─ReLU: 3-41 [12, 512, 8, 8] --
├─Sequential: 1-7 [12, 512, 16, 16] --
│ └─ConvTranspose2d: 2-18 [12, 512, 16, 16] 1,049,088
│ └─BatchNorm2d: 2-19 [12, 512, 16, 16] 1,024
│ └─ReLU: 2-20 [12, 512, 16, 16] --
├─Sequential: 1-8 [12, 256, 16, 16] --
│ └─Conv2d: 2-21 [12, 256, 16, 16] 1,769,728
│ └─BatchNorm2d: 2-22 [12, 256, 16, 16] 512
│ └─ReLU: 2-23 [12, 256, 16, 16] --
│ └─Conv2d: 2-24 [12, 256, 16, 16] 590,080
│ └─BatchNorm2d: 2-25 [12, 256, 16, 16] 512
│ └─ReLU: 2-26 [12, 256, 16, 16] --
├─Sequential: 1-9 [12, 256, 32, 32] --
│ └─ConvTranspose2d: 2-27 [12, 256, 32, 32] 262,400
│ └─BatchNorm2d: 2-28 [12, 256, 32, 32] 512
│ └─ReLU: 2-29 [12, 256, 32, 32] --
├─Sequential: 1-10 [12, 128, 32, 32] --
│ └─Conv2d: 2-30 [12, 128, 32, 32] 442,496
│ └─BatchNorm2d: 2-31 [12, 128, 32, 32] 256
│ └─ReLU: 2-32 [12, 128, 32, 32] --
│ └─Conv2d: 2-33 [12, 128, 32, 32] 147,584
│ └─BatchNorm2d: 2-34 [12, 128, 32, 32] 256
│ └─ReLU: 2-35 [12, 128, 32, 32] --
├─Sequential: 1-11 [12, 128, 64, 64] --
│ └─ConvTranspose2d: 2-36 [12, 128, 64, 64] 65,664
│ └─BatchNorm2d: 2-37 [12, 128, 64, 64] 256
│ └─ReLU: 2-38 [12, 128, 64, 64] --
├─Sequential: 1-12 [12, 64, 64, 64] --
│ └─Conv2d: 2-39 [12, 64, 64, 64] 110,656
│ └─BatchNorm2d: 2-40 [12, 64, 64, 64] 128
│ └─ReLU: 2-41 [12, 64, 64, 64] --
│ └─Conv2d: 2-42 [12, 64, 64, 64] 36,928
│ └─BatchNorm2d: 2-43 [12, 64, 64, 64] 128
│ └─ReLU: 2-44 [12, 64, 64, 64] --
├─Sequential: 1-13 [12, 64, 128, 128] --
│ └─ConvTranspose2d: 2-45 [12, 64, 128, 128] 16,448
│ └─BatchNorm2d: 2-46 [12, 64, 128, 128] 128
│ └─ReLU: 2-47 [12, 64, 128, 128] --
├─Sequential: 1-14 [12, 32, 128, 128] --
│ └─Conv2d: 2-48 [12, 32, 128, 128] 36,896
│ └─BatchNorm2d: 2-49 [12, 32, 128, 128] 64
│ └─ReLU: 2-50 [12, 32, 128, 128] --
│ └─Conv2d: 2-51 [12, 32, 128, 128] 9,248
│ └─BatchNorm2d: 2-52 [12, 32, 128, 128] 64
│ └─ReLU: 2-53 [12, 32, 128, 128] --
├─Sequential: 1-15 [12, 32, 256, 256] --
│ └─ConvTranspose2d: 2-54 [12, 32, 256, 256] 4,128
│ └─BatchNorm2d: 2-55 [12, 32, 256, 256] 64
│ └─ReLU: 2-56 [12, 32, 256, 256] --
├─Sequential: 1-16 [12, 16, 256, 256] --
│ └─Conv2d: 2-57 [12, 16, 256, 256] 6,928
│ └─BatchNorm2d: 2-58 [12, 16, 256, 256] 32
│ └─ReLU: 2-59 [12, 16, 256, 256] --
│ └─Conv2d: 2-60 [12, 16, 256, 256] 2,320
│ └─BatchNorm2d: 2-61 [12, 16, 256, 256] 32
│ └─ReLU: 2-62 [12, 16, 256, 256] --
├─Conv2d: 1-17 [12, 10, 256, 256] 170
==========================================================================================
Total params: 16,247,074
Trainable params: 16,247,074
Non-trainable params: 0
Total mult-adds (G): 84.99
==========================================================================================
Input size (MB): 9.44
Forward/backward pass size (MB): 2648.70
Params size (MB): 62.94
Estimated Total Size (MB): 2721.08
==========================================================================================
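Before freezing anything, a quick sanity check can confirm that a batch passes through the full architecture and that the output has the expected shape (a minimal sketch, assuming the model instantiated above):

with torch.no_grad():
    out = model(torch.rand(2, 3, 256, 256).to(device))
print(out.shape)
# Expected: torch.Size([2, 10, 256, 256])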
Similar to the VGGNet-16 example above, I can freeze the backbone parameters to reduce the number of trainable parameters in the model. This is accomplished by setting the requires_grad property for the backbone or encoder layers extracted from the ResNet architecture to False. Printing the summary, you can see that only a subset of the parameters is now trainable.
for l in model.base_layers:
    for param in l.parameters():
        param.requires_grad = False
summary(model, (12,3,256,256))
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
myUNetResNet [12, 10, 256, 256] 513,000
├─Sequential: 1-1 [12, 16, 256, 256] --
│ └─Conv2d: 2-1 [12, 16, 256, 256] 448
│ └─BatchNorm2d: 2-2 [12, 16, 256, 256] 32
│ └─ReLU: 2-3 [12, 16, 256, 256] --
│ └─Conv2d: 2-4 [12, 16, 256, 256] 2,320
│ └─BatchNorm2d: 2-5 [12, 16, 256, 256] 32
│ └─ReLU: 2-6 [12, 16, 256, 256] --
├─Sequential: 1-2 [12, 64, 128, 128] --
│ └─Conv2d: 2-7 [12, 64, 128, 128] (9,408)
│ └─BatchNorm2d: 2-8 [12, 64, 128, 128] (128)
│ └─ReLU: 2-9 [12, 64, 128, 128] --
├─Sequential: 1-3 [12, 64, 64, 64] --
│ └─MaxPool2d: 2-10 [12, 64, 64, 64] --
│ └─Sequential: 2-11 [12, 64, 64, 64] --
│ │ └─BasicBlock: 3-1 [12, 64, 64, 64] (73,984)
│ │ └─BasicBlock: 3-2 [12, 64, 64, 64] (73,984)
├─Sequential: 1-4 [12, 128, 32, 32] --
│ └─BasicBlock: 2-12 [12, 128, 32, 32] --
│ │ └─Conv2d: 3-3 [12, 128, 32, 32] (73,728)
│ │ └─BatchNorm2d: 3-4 [12, 128, 32, 32] (256)
│ │ └─ReLU: 3-5 [12, 128, 32, 32] --
│ │ └─Conv2d: 3-6 [12, 128, 32, 32] (147,456)
│ │ └─BatchNorm2d: 3-7 [12, 128, 32, 32] (256)
│ │ └─Sequential: 3-8 [12, 128, 32, 32] (8,448)
│ │ └─ReLU: 3-9 [12, 128, 32, 32] --
│ └─BasicBlock: 2-13 [12, 128, 32, 32] --
│ │ └─Conv2d: 3-10 [12, 128, 32, 32] (147,456)
│ │ └─BatchNorm2d: 3-11 [12, 128, 32, 32] (256)
│ │ └─ReLU: 3-12 [12, 128, 32, 32] --
│ │ └─Conv2d: 3-13 [12, 128, 32, 32] (147,456)
│ │ └─BatchNorm2d: 3-14 [12, 128, 32, 32] (256)
│ │ └─ReLU: 3-15 [12, 128, 32, 32] --
├─Sequential: 1-5 [12, 256, 16, 16] --
│ └─BasicBlock: 2-14 [12, 256, 16, 16] --
│ │ └─Conv2d: 3-16 [12, 256, 16, 16] (294,912)
│ │ └─BatchNorm2d: 3-17 [12, 256, 16, 16] (512)
│ │ └─ReLU: 3-18 [12, 256, 16, 16] --
│ │ └─Conv2d: 3-19 [12, 256, 16, 16] (589,824)
│ │ └─BatchNorm2d: 3-20 [12, 256, 16, 16] (512)
│ │ └─Sequential: 3-21 [12, 256, 16, 16] (33,280)
│ │ └─ReLU: 3-22 [12, 256, 16, 16] --
│ └─BasicBlock: 2-15 [12, 256, 16, 16] --
│ │ └─Conv2d: 3-23 [12, 256, 16, 16] (589,824)
│ │ └─BatchNorm2d: 3-24 [12, 256, 16, 16] (512)
│ │ └─ReLU: 3-25 [12, 256, 16, 16] --
│ │ └─Conv2d: 3-26 [12, 256, 16, 16] (589,824)
│ │ └─BatchNorm2d: 3-27 [12, 256, 16, 16] (512)
│ │ └─ReLU: 3-28 [12, 256, 16, 16] --
├─Sequential: 1-6 [12, 512, 8, 8] --
│ └─BasicBlock: 2-16 [12, 512, 8, 8] --
│ │ └─Conv2d: 3-29 [12, 512, 8, 8] (1,179,648)
│ │ └─BatchNorm2d: 3-30 [12, 512, 8, 8] (1,024)
│ │ └─ReLU: 3-31 [12, 512, 8, 8] --
│ │ └─Conv2d: 3-32 [12, 512, 8, 8] (2,359,296)
│ │ └─BatchNorm2d: 3-33 [12, 512, 8, 8] (1,024)
│ │ └─Sequential: 3-34 [12, 512, 8, 8] (132,096)
│ │ └─ReLU: 3-35 [12, 512, 8, 8] --
│ └─BasicBlock: 2-17 [12, 512, 8, 8] --
│ │ └─Conv2d: 3-36 [12, 512, 8, 8] (2,359,296)
│ │ └─BatchNorm2d: 3-37 [12, 512, 8, 8] (1,024)
│ │ └─ReLU: 3-38 [12, 512, 8, 8] --
│ │ └─Conv2d: 3-39 [12, 512, 8, 8] (2,359,296)
│ │ └─BatchNorm2d: 3-40 [12, 512, 8, 8] (1,024)
│ │ └─ReLU: 3-41 [12, 512, 8, 8] --
├─Sequential: 1-7 [12, 512, 16, 16] --
│ └─ConvTranspose2d: 2-18 [12, 512, 16, 16] 1,049,088
│ └─BatchNorm2d: 2-19 [12, 512, 16, 16] 1,024
│ └─ReLU: 2-20 [12, 512, 16, 16] --
├─Sequential: 1-8 [12, 256, 16, 16] --
│ └─Conv2d: 2-21 [12, 256, 16, 16] 1,769,728
│ └─BatchNorm2d: 2-22 [12, 256, 16, 16] 512
│ └─ReLU: 2-23 [12, 256, 16, 16] --
│ └─Conv2d: 2-24 [12, 256, 16, 16] 590,080
│ └─BatchNorm2d: 2-25 [12, 256, 16, 16] 512
│ └─ReLU: 2-26 [12, 256, 16, 16] --
├─Sequential: 1-9 [12, 256, 32, 32] --
│ └─ConvTranspose2d: 2-27 [12, 256, 32, 32] 262,400
│ └─BatchNorm2d: 2-28 [12, 256, 32, 32] 512
│ └─ReLU: 2-29 [12, 256, 32, 32] --
├─Sequential: 1-10 [12, 128, 32, 32] --
│ └─Conv2d: 2-30 [12, 128, 32, 32] 442,496
│ └─BatchNorm2d: 2-31 [12, 128, 32, 32] 256
│ └─ReLU: 2-32 [12, 128, 32, 32] --
│ └─Conv2d: 2-33 [12, 128, 32, 32] 147,584
│ └─BatchNorm2d: 2-34 [12, 128, 32, 32] 256
│ └─ReLU: 2-35 [12, 128, 32, 32] --
├─Sequential: 1-11 [12, 128, 64, 64] --
│ └─ConvTranspose2d: 2-36 [12, 128, 64, 64] 65,664
│ └─BatchNorm2d: 2-37 [12, 128, 64, 64] 256
│ └─ReLU: 2-38 [12, 128, 64, 64] --
├─Sequential: 1-12 [12, 64, 64, 64] --
│ └─Conv2d: 2-39 [12, 64, 64, 64] 110,656
│ └─BatchNorm2d: 2-40 [12, 64, 64, 64] 128
│ └─ReLU: 2-41 [12, 64, 64, 64] --
│ └─Conv2d: 2-42 [12, 64, 64, 64] 36,928
│ └─BatchNorm2d: 2-43 [12, 64, 64, 64] 128
│ └─ReLU: 2-44 [12, 64, 64, 64] --
├─Sequential: 1-13 [12, 64, 128, 128] --
│ └─ConvTranspose2d: 2-45 [12, 64, 128, 128] 16,448
│ └─BatchNorm2d: 2-46 [12, 64, 128, 128] 128
│ └─ReLU: 2-47 [12, 64, 128, 128] --
├─Sequential: 1-14 [12, 32, 128, 128] --
│ └─Conv2d: 2-48 [12, 32, 128, 128] 36,896
│ └─BatchNorm2d: 2-49 [12, 32, 128, 128] 64
│ └─ReLU: 2-50 [12, 32, 128, 128] --
│ └─Conv2d: 2-51 [12, 32, 128, 128] 9,248
│ └─BatchNorm2d: 2-52 [12, 32, 128, 128] 64
│ └─ReLU: 2-53 [12, 32, 128, 128] --
├─Sequential: 1-15 [12, 32, 256, 256] --
│ └─ConvTranspose2d: 2-54 [12, 32, 256, 256] 4,128
│ └─BatchNorm2d: 2-55 [12, 32, 256, 256] 64
│ └─ReLU: 2-56 [12, 32, 256, 256] --
├─Sequential: 1-16 [12, 16, 256, 256] --
│ └─Conv2d: 2-57 [12, 16, 256, 256] 6,928
│ └─BatchNorm2d: 2-58 [12, 16, 256, 256] 32
│ └─ReLU: 2-59 [12, 16, 256, 256] --
│ └─Conv2d: 2-60 [12, 16, 256, 256] 2,320
│ └─BatchNorm2d: 2-61 [12, 16, 256, 256] 32
│ └─ReLU: 2-62 [12, 16, 256, 256] --
├─Conv2d: 1-17 [12, 10, 256, 256] 170
==========================================================================================
Total params: 16,247,074
Trainable params: 4,557,562
Non-trainable params: 11,689,512
Total mult-adds (G): 84.99
==========================================================================================
Input size (MB): 9.44
Forward/backward pass size (MB): 2648.70
Params size (MB): 62.94
Estimated Total Size (MB): 2721.08
==========================================================================================
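As a quick check on the summary above, the trainable and frozen parameter counts can also be tallied directly from the model's parameters. This is a minimal sketch assuming the same model instance defined earlier in the module; the printed totals should match the trainable and non-trainable counts reported by torchinfo.
# Count parameters by trainable status to confirm the freeze took effect
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
n_frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(f"Trainable: {n_trainable:,} | Frozen: {n_frozen:,}")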
Concluding Remarks
You can now define a basic UNet architecture, train a UNet model, and build UNet variants that use common CNN architectures as the backbone or encoder and can accept pre-trained weights. However, other semantic segmentation architectures are more complex and difficult to build from scratch, and you may want to pair a wide variety of backbones with a variety of architectures. In the next module, we will explore the Segmentation Models package, which builds on PyTorch and provides many different semantic segmentation architectures, backbone encoders, and pre-trained weights without requiring you to define the models or their components on your own.
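As a brief preview of how compact this becomes, here is a minimal sketch of defining a UNet with a pre-trained ResNet-34 encoder using the Segmentation Models package (segmentation_models_pytorch). The arguments shown follow that package's documented API; the encoder choice and class count here are illustrative, with 10 classes chosen to match the examples in this module.
# Illustrative example: UNet with a pre-trained ResNet-34 encoder via
# segmentation_models_pytorch; the next module covers this package in detail
import segmentation_models_pytorch as smp
model = smp.Unet(
    encoder_name="resnet34",      # backbone/encoder architecture
    encoder_weights="imagenet",   # load ImageNet pre-trained weights
    in_channels=3,                # three-band (RGB) input
    classes=10,                   # number of output classes
)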