SegFormer Backbone
Implement UNet with SegFormer Backbone
Introduction
The Segmentation Models package provides a variety of semantic segmentation architectures along with CNN backbones that can be used as the encoder. These backbones also have pre-trained weights that can be used to implement transfer learning. Most of the available backbones are traditional CNN architectures, such as ResNet or VGGNet. However, the package also provides an implementation of the transformer-based SegFormer architecture, including the B0 through B5 variants, which range from ~3 to ~81 million trainable parameters. These SegFormer-based encoders also have pre-trained weights available that were learned from ImageNet.
In this module, we will explore classifying the Landcover.ai dataset using a UNet architecture with a SegFormer backbone. In other words, we will implement UNet with an encoder based on a transformer, as opposed to a CNN, architecture. As you will see, this requires very little modification of the code, which is one of the great benefits of the Segmentation Models package.
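If you would like to see which SegFormer encoders are available in your installed version of the package, you can list the registered encoder names and filter for the Mix Transformer ("mit") prefix, as sketched below. This assumes a recent release of segmentation_models_pytorch that exposes the get_encoder_names helper.
import segmentation_models_pytorch as smp

# List the registered encoders and keep the SegFormer (Mix Transformer) variants
mit_encoders = [e for e in smp.encoders.get_encoder_names() if e.startswith("mit_")]
print(mit_encoders)  # expected to include mit_b0 through mit_b5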
Preparation
The preparation steps here are the same as those implemented in the Train UNet module since the same dataset, Landcover.ai, is used.
- The needed libraries are imported.
- The device is defined as the GPU.
- The training, validation, and test sets data frames are read in.
- A DataSet subclass is defined based on PyTorch’s DataSet class.
- Transformations/augmentations are defined.
- DataSets and DataLoaders are instantiated.
Please see the Train UNet module for more details. Again, the same process is implemented here as was implemented in that module.
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import os
import math
import cv2
import torch
import torch.nn as nn
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import albumentations as A
import segmentation_models_pytorch as smp
import torchmetrics as tm
import rasterio as rio
from rasterio.plot import show
import torchinfo
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
cuda:0
CLASSES = ['background', 'building', 'woodlands', 'water', 'road']
OUTPUT_DIR = "C:/myFiles/work/dl/lancoverai/"
= pd.read_csv("C:/myFiles/work/dl/landcover.ai.v1/train.txt", header=None, names=["file"])
trainDF "img"] = OUTPUT_DIR + trainDF['file'] + ".jpg"
trainDF["mask"] = OUTPUT_DIR + trainDF['file'] + "_m.png" trainDF[
= pd.read_csv("C:/myFiles/work/dl/landcover.ai.v1/val.txt", header=None, names=["file"])
valDF "img"] = OUTPUT_DIR + valDF['file'] + ".jpg"
valDF["mask"] = OUTPUT_DIR + valDF['file'] + "_m.png" valDF[
= pd.read_csv("C:/myFiles/work/dl/landcover.ai.v1/test.txt", header=None, names=["file"])
testDF "img"] = OUTPUT_DIR + testDF['file'] + ".jpg"
testDF["mask"] = OUTPUT_DIR + testDF['file'] + "_m.png" testDF[
# Subclass and define custom dataset ===========================
class MultiClassSegDataset(Dataset):

    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __getitem__(self, idx):
        image_name = self.df.iloc[idx, 1]
        mask_name = self.df.iloc[idx, 2]
        image = cv2.imread(image_name)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(mask_name, cv2.IMREAD_UNCHANGED)
        image = image.astype('uint8')
        mask = mask[:,:,0]
        if self.transform is not None:
            transformed = self.transform(image=image, mask=mask)
            image = transformed["image"]
            mask = transformed["mask"]
            image = torch.from_numpy(image)
            mask = torch.from_numpy(mask)
            image = image.permute(2, 0, 1)
            image = image.float()/255
            mask = mask.long()
        else:
            image = torch.from_numpy(image)
            mask = torch.from_numpy(mask)
            image = image.permute(2, 0, 1)
            image = image.float()/255
            mask = mask.long()
        return image, mask

    def __len__(self):
        return len(self.df)
test_transform = A.Compose(
    [A.PadIfNeeded(min_height=512, min_width=512, border_mode=4), A.Resize(512, 512),]
)

train_transform = A.Compose(
    [
        A.PadIfNeeded(min_height=512, min_width=512, border_mode=4),
        A.Resize(512, 512),
        A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.MedianBlur(blur_limit=3, always_apply=False, p=0.1),
    ]
)
trainDS = MultiClassSegDataset(trainDF, transform=train_transform)
valDS = MultiClassSegDataset(valDF, transform=test_transform)
print("Number of Training Samples: " + str(len(trainDS)) + " Number of Testing Samples: " + str(len(valDS)))
Number of Training Samples: 7470 Number of Testing Samples: 1602
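As a quick check that is not included in the original module, you can pull a single sample from the training DataSet to confirm that the image and mask tensors have the expected shapes and data types.
# Optional check: inspect one chip and its mask returned by the custom DataSet
imgChk, mskChk = trainDS[0]
print(imgChk.shape, imgChk.dtype)  # expected: torch.Size([3, 512, 512]) torch.float32
print(mskChk.shape, mskChk.dtype)  # expected: torch.Size([512, 512]) torch.int64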
trainDL = DataLoader(trainDS, batch_size=8, shuffle=True, sampler=None,
                     batch_sampler=None, num_workers=0, collate_fn=None,
                     pin_memory=False, drop_last=True, timeout=0,
                     worker_init_fn=None)
valDL = DataLoader(valDS, batch_size=8, shuffle=False, sampler=None,
                   batch_sampler=None, num_workers=0, collate_fn=None,
                   pin_memory=False, drop_last=True, timeout=0,
                   worker_init_fn=None)
Define Model
Defining a UNet model with a SegFormer backbone using Segmentation Models is very similar to defining a model using a CNN-based backbone. We simply need to provide the name of the desired SegFormer architecture as the argument for the encoder_name parameter. In the example, I am using the “mit_b2” version, which has ~24 million parameters. I am also initializing the model using parameters learned from ImageNet and not applying an activation function. The architecture is defined using the same Segmentation Models syntax that was used in the Segmentation Models module. Note that we are differentiating 5 classes, and the input data have three channels.
= "mit_b2"
encoder = "imagenet"
encoder_weights = None activation
model = smp.Unet(
    encoder_name=encoder,
    encoder_weights=encoder_weights,
    classes=5,
    activation=activation,
    in_channels=3
).to(device)
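If you would like to confirm the normalization associated with the pre-trained encoder weights, Segmentation Models provides a preprocessing function tied to each encoder and weight set. The optional sketch below is not part of this module's workflow; here, the images are simply rescaled by dividing by 255 inside the DataSet class rather than being normalized with the ImageNet statistics.
from segmentation_models_pytorch.encoders import get_preprocessing_fn

# Optional check (not used elsewhere in this module): obtain the preprocessing
# function associated with the pre-trained "mit_b2" ImageNet weights and apply
# it to a random dummy image to confirm it runs.
preprocess_input = get_preprocessing_fn(encoder, pretrained=encoder_weights)
dummy = np.random.randint(0, 256, size=(512, 512, 3)).astype("float32")
print(preprocess_input(dummy).shape)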
Train Model
Training the model requires the following preparation steps:
- Define and instantiate a loss metric. In this case we will use a multiclass Dice loss (a generic form of the Dice loss is sketched after this list).
- Define accuracy assessment metrics to monitor during the learning process.
- Freeze the encoder. I have chosen not to train the encoder to speed up the learning process. Training the encoder may allow for improved model performance; however, this could also lead to overfitting, especially given the large number of trainable parameters for the SegFormer backbone. If you are curious, feel free to train the model with an unfrozen encoder.
- Define an optimizer. We will use AdamW with the default learning rate.
- Define the number of training epochs, in this case 25, and the location at which to save the logs and final model.
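For reference, a generic form of the soft Dice loss for a single class (the multiclass version averages a term of this form over the classes) can be written as follows. This is a general formulation for context, not necessarily the exact implementation used inside Segmentation Models:

$$ L_{\text{Dice}} = 1 - \frac{2\sum_{i} p_i g_i + s}{\sum_{i} p_i + \sum_{i} g_i + s} $$

where $p_i$ are the predicted class probabilities at each pixel, $g_i$ are the one-hot reference labels, and $s$ is the smoothing term (the smooth argument, 1e-8 here) that guards against division by zero.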
criterion = smp.losses.DiceLoss(mode="multiclass", classes=5, from_logits=True, smooth=1e-8)
acc = tm.Accuracy(task="multiclass", average="micro", num_classes=5).to(device)
f1 = tm.F1Score(task="multiclass", average="macro", num_classes=5).to(device)
kappa = tm.CohenKappa(task="multiclass", average="micro", num_classes=5).to(device)
#https://github.com/qubvel/segmentation_models.pytorch/issues/79
def freeze_encoder(model):
    for child in model.encoder.children():
        for param in child.parameters():
            param.requires_grad = False
    return

def unfreeze(model):
    for child in model.children():
        for param in child.parameters():
            param.requires_grad = True
    return
freeze_encoder(model)
optimizer = torch.optim.AdamW(model.parameters())
epochs = 25
saveFolder = "C:/myFiles/work/dl/landcoverai_segformer/"
Before training the model, I also print a summary of the model using torchinfo. You can see that the encoder is now very different from a CNN architecture since it is made up of the components of SegFormer. Also, you can see that many of the parameters are non-trainable since I chose to freeze them.
torchinfo.summary(model, (8, 3, 512, 512))
===============================================================================================
Layer (type:depth-idx) Output Shape Param #
===============================================================================================
Unet [8, 5, 512, 512] --
├─MixVisionTransformerEncoder: 1-1 [8, 3, 512, 512] --
│ └─OverlapPatchEmbed: 2-1 [8, 16384, 64] --
│ │ └─Conv2d: 3-1 [8, 64, 128, 128] (9,472)
│ │ └─LayerNorm: 3-2 [8, 16384, 64] (128)
│ └─ModuleList: 2-2 -- --
│ │ └─Block: 3-3 [8, 16384, 64] (314,880)
│ │ └─Block: 3-4 [8, 16384, 64] (314,880)
│ │ └─Block: 3-5 [8, 16384, 64] (314,880)
│ └─LayerNorm: 2-3 [8, 16384, 64] (128)
│ └─OverlapPatchEmbed: 2-4 [8, 4096, 128] --
│ │ └─Conv2d: 3-6 [8, 128, 64, 64] (73,856)
│ │ └─LayerNorm: 3-7 [8, 4096, 128] (256)
│ └─ModuleList: 2-5 -- --
│ │ └─Block: 3-8 [8, 4096, 128] (465,920)
│ │ └─Block: 3-9 [8, 4096, 128] (465,920)
│ │ └─Block: 3-10 [8, 4096, 128] (465,920)
│ │ └─Block: 3-11 [8, 4096, 128] (465,920)
│ └─LayerNorm: 2-6 [8, 4096, 128] (256)
│ └─OverlapPatchEmbed: 2-7 [8, 1024, 320] --
│ │ └─Conv2d: 3-12 [8, 320, 32, 32] (368,960)
│ │ └─LayerNorm: 3-13 [8, 1024, 320] (640)
│ └─ModuleList: 2-8 -- --
│ │ └─Block: 3-14 [8, 1024, 320] (1,656,320)
│ │ └─Block: 3-15 [8, 1024, 320] (1,656,320)
│ │ └─Block: 3-16 [8, 1024, 320] (1,656,320)
│ │ └─Block: 3-17 [8, 1024, 320] (1,656,320)
│ │ └─Block: 3-18 [8, 1024, 320] (1,656,320)
│ │ └─Block: 3-19 [8, 1024, 320] (1,656,320)
│ └─LayerNorm: 2-9 [8, 1024, 320] (640)
│ └─OverlapPatchEmbed: 2-10 [8, 256, 512] --
│ │ └─Conv2d: 3-20 [8, 512, 16, 16] (1,475,072)
│ │ └─LayerNorm: 3-21 [8, 256, 512] (1,024)
│ └─ModuleList: 2-11 -- --
│ │ └─Block: 3-22 [8, 256, 512] (3,172,864)
│ │ └─Block: 3-23 [8, 256, 512] (3,172,864)
│ │ └─Block: 3-24 [8, 256, 512] (3,172,864)
│ └─LayerNorm: 2-12 [8, 256, 512] (1,024)
├─UnetDecoder: 1-2 [8, 16, 512, 512] --
│ └─Identity: 2-13 [8, 512, 16, 16] --
│ └─ModuleList: 2-14 -- --
│ │ └─DecoderBlock: 3-25 [8, 256, 32, 32] 2,507,776
│ │ └─DecoderBlock: 3-26 [8, 128, 64, 64] 590,336
│ │ └─DecoderBlock: 3-27 [8, 64, 128, 128] 147,712
│ │ └─DecoderBlock: 3-28 [8, 32, 256, 256] 27,776
│ │ └─DecoderBlock: 3-29 [8, 16, 512, 512] 6,976
├─SegmentationHead: 1-3 [8, 5, 512, 512] --
│ └─Conv2d: 2-15 [8, 5, 512, 512] 725
│ └─Identity: 2-16 [8, 5, 512, 512] --
│ └─Activation: 2-17 [8, 5, 512, 512] --
│ │ └─Identity: 3-30 [8, 5, 512, 512] --
===============================================================================================
Total params: 27,477,589
Trainable params: 3,281,301
Non-trainable params: 24,196,288
Total mult-adds (G): 110.78
===============================================================================================
Input size (MB): 25.17
Forward/backward pass size (MB): 9101.64
Params size (MB): 109.91
Estimated Total Size (MB): 9236.72
===============================================================================================
We are now ready to train the model using our standard training loop. A few notes:
- I am training for a total of 25 epochs with a mini-batch size of 8 chips.
- The validation set will be predicted at the end of each training epoch in order to assess the model and monitor for overfitting.
- I am using the Dice loss and monitoring overall accuracy, class-aggregated macro average F1-score, and the Kappa statistic.
- At the end of each training epoch, the model will only be saved to disk if there is an improvement in the F1-score for the validation data.
If you execute the training loop, note that it will take several hours to complete. I have provided a trained model file if you would like to run the following model assessment code without training the model.
eNum = []
t_loss = []
t_acc = []
t_f1 = []
t_kappa = []
v_loss = []
v_acc = []
v_f1 = []
v_kappa = []

f1VMax = 0.0
# Loop over epochs
for epoch in range(1, epochs+1):
    model.train()
    running_loss = 0.0
    # Loop over training batches
    for batch_idx, (inputs, targets) in enumerate(trainDL):
        # Get data and move to device
        inputs, targets = inputs.to(device), targets.to(device)

        # Clear gradients
        optimizer.zero_grad()
        # Predict data
        outputs = model(inputs)
        # Calculate loss
        loss = criterion(outputs, targets)

        # Calculate metrics
        accT = acc(outputs, targets)
        f1T = f1(outputs, targets)
        kappaT = kappa(outputs, targets)

        # Backpropagate
        loss.backward()

        # Update parameters
        optimizer.step()

        # Update running loss with batch results
        running_loss += loss.item()

    # Accumulate loss and metrics at end of training epoch
    epoch_loss = running_loss/len(trainDL)
    accT = acc.compute()
    f1T = f1.compute()
    kappaT = kappa.compute()

    # Print losses and metrics at end of each training epoch
    print(f'Epoch: {epoch}, Training Loss: {epoch_loss:.4f}, Training Accuracy: {accT:.4f}, Training F1: {f1T:.4f}, Training Kappa: {kappaT:.4f}')

    # Append results
    eNum.append(epoch)
    t_loss.append(epoch_loss)
    t_acc.append(accT.detach().cpu().numpy())
    t_f1.append(f1T.detach().cpu().numpy())
    t_kappa.append(kappaT.detach().cpu().numpy())

    # Reset metrics
    acc.reset()
    f1.reset()
    kappa.reset()

    # Loop over validation batches
    with torch.no_grad():
        # Initialize running validation loss
        running_loss_v = 0.0
        for batch_idx, (inputs, targets) in enumerate(valDL):
            # Get data and move to device
            inputs, targets = inputs.to(device), targets.to(device)

            # Predict data
            outputs = model(inputs)
            # Calculate validation loss
            loss_v = criterion(outputs, targets)

            # Update running loss with batch results
            running_loss_v += loss_v.item()

            # Calculate metrics
            accV = acc(outputs, targets)
            f1V = f1(outputs, targets)
            kappaV = kappa(outputs, targets)

    # Accumulate loss and metrics at end of validation epoch
    epoch_loss_v = running_loss_v/len(valDL)
    accV = acc.compute()
    f1V = f1.compute()
    kappaV = kappa.compute()

    # Print validation loss and metrics
    print(f'Validation Loss: {epoch_loss_v:.4f}, Validation Accuracy: {accV:.4f}, Validation F1: {f1V:.4f}, Validation Kappa: {kappaV:.4f}')

    # Append results
    v_loss.append(epoch_loss_v)
    v_acc.append(accV.detach().cpu().numpy())
    v_f1.append(f1V.detach().cpu().numpy())
    v_kappa.append(kappaV.detach().cpu().numpy())

    # Reset metrics
    acc.reset()
    f1.reset()
    kappa.reset()

    # Save model if validation F1-score improves
    f1V2 = f1V.detach().cpu().numpy()
    if f1V2 > f1VMax:
        f1VMax = f1V2
        torch.save(model.state_dict(), saveFolder + 'lcai_SegFormer_model.pt')
        print(f'Model saved for epoch {epoch}.')
SeNum = pd.Series(eNum, name="epoch")
St_loss = pd.Series(t_loss, name="training_loss")
St_acc = pd.Series(t_acc, name="training_accuracy")
St_f1 = pd.Series(t_f1, name="training_f1")
St_kappa = pd.Series(t_kappa, name="training_kappa")
Sv_loss = pd.Series(v_loss, name="val_loss")
Sv_acc = pd.Series(v_acc, name="val_accuracy")
Sv_f1 = pd.Series(v_f1, name="val_f1")
Sv_kappa = pd.Series(v_kappa, name="val_kappa")
resultsDF = pd.concat([SeNum, St_loss, St_acc, St_f1, St_kappa, Sv_loss, Sv_acc, Sv_f1, Sv_kappa], axis=1)

resultsDF.to_csv(saveFolder + "resultsLCAISegFormer.csv")
Model Assessment
The remainder of the code is used to assess the trained model using the withheld testing data. This includes the following:
- Read in the training logs and plot the training and validation losses and F1-scores.
- Re-instantiate the model and load in the saved parameters.
- Define the test DataSet and DataLoader.
- Instantiate new assessment metrics with no aggregation so that metrics are obtained separately for each class.
- Execute a validation loop and print the obtained assessment metrics.
= pd.read_csv("data/models/resultsLCAISegFormer.csv") resultsDF
'figure.figsize'] = [10, 10]
plt.rcParams[= resultsDF.plot(x='epoch', y="training_loss")
firstPlot ='epoch', y="val_loss", ax=firstPlot)
resultsDF.plot(x plt.show()
'figure.figsize'] = [10, 10]
plt.rcParams[= resultsDF.plot(x='epoch', y="training_f1")
firstPlot ='epoch', y="val_f1", ax=firstPlot)
resultsDF.plot(x plt.show()
model = smp.Unet(
    encoder_name=encoder,
    encoder_weights=encoder_weights,
    classes=5,
    activation=activation,
    in_channels=3
).to(device)

model.load_state_dict(torch.load(saveFolder + "lcai_SegFormer_model.pt"))
<All keys matched successfully>
testDS = MultiClassSegDataset(testDF, transform=test_transform)

testDL = DataLoader(testDS, batch_size=8, shuffle=False, sampler=None,
                    batch_sampler=None, num_workers=0, collate_fn=None,
                    pin_memory=False, drop_last=True, timeout=0,
                    worker_init_fn=None)
= tm.Accuracy(task="multiclass", num_classes=5).to(device)
acc = tm.F1Score(task="multiclass", num_classes=5, average='none').to(device)
f1 = tm.Recall(task="multiclass", num_classes=5, average='none').to(device)
recall = tm.Precision(task="multiclass", num_classes=5, average='none').to(device)
precision = tm.CohenKappa(task="multiclass", num_classes=5).to(device)
kappa = tm.ConfusionMatrix(task="multiclass", num_classes=5).to(device) cm
model.eval()
Unet(
(encoder): MixVisionTransformerEncoder(
(patch_embed1): OverlapPatchEmbed(
(proj): Conv2d(3, 64, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
(norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
)
(patch_embed2): OverlapPatchEmbed(
(proj): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
)
(patch_embed3): OverlapPatchEmbed(
(proj): Conv2d(128, 320, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
)
(patch_embed4): OverlapPatchEmbed(
(proj): Conv2d(320, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
)
(block1): ModuleList(
(0): Block(
(norm1): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(q): Linear(in_features=64, out_features=64, bias=True)
(kv): Linear(in_features=64, out_features=128, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=64, out_features=64, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
(sr): Conv2d(64, 64, kernel_size=(8, 8), stride=(8, 8))
(norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
)
(drop_path): Identity()
(norm2): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=64, out_features=256, bias=True)
(dwconv): DWConv(
(dwconv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256)
)
(act): GELU(approximate='none')
(fc2): Linear(in_features=256, out_features=64, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(1): Block(
(norm1): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(q): Linear(in_features=64, out_features=64, bias=True)
(kv): Linear(in_features=64, out_features=128, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=64, out_features=64, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
(sr): Conv2d(64, 64, kernel_size=(8, 8), stride=(8, 8))
(norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
)
(drop_path): DropPath(drop_prob=0.007)
(norm2): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=64, out_features=256, bias=True)
(dwconv): DWConv(
(dwconv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256)
)
(act): GELU(approximate='none')
(fc2): Linear(in_features=256, out_features=64, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(2): Block(
(norm1): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(q): Linear(in_features=64, out_features=64, bias=True)
(kv): Linear(in_features=64, out_features=128, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=64, out_features=64, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
(sr): Conv2d(64, 64, kernel_size=(8, 8), stride=(8, 8))
(norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
)
(drop_path): DropPath(drop_prob=0.013)
(norm2): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=64, out_features=256, bias=True)
(dwconv): DWConv(
(dwconv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256)
)
(act): GELU(approximate='none')
(fc2): Linear(in_features=256, out_features=64, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
)
(norm1): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
(block2): ModuleList(
(0): Block(
(norm1): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(q): Linear(in_features=128, out_features=128, bias=True)
(kv): Linear(in_features=128, out_features=256, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=128, out_features=128, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
(sr): Conv2d(128, 128, kernel_size=(4, 4), stride=(4, 4))
(norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
)
(drop_path): DropPath(drop_prob=0.020)
(norm2): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=128, out_features=512, bias=True)
(dwconv): DWConv(
(dwconv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512)
)
(act): GELU(approximate='none')
(fc2): Linear(in_features=512, out_features=128, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(1): Block(
(norm1): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(q): Linear(in_features=128, out_features=128, bias=True)
(kv): Linear(in_features=128, out_features=256, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=128, out_features=128, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
(sr): Conv2d(128, 128, kernel_size=(4, 4), stride=(4, 4))
(norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
)
(drop_path): DropPath(drop_prob=0.027)
(norm2): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=128, out_features=512, bias=True)
(dwconv): DWConv(
(dwconv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512)
)
(act): GELU(approximate='none')
(fc2): Linear(in_features=512, out_features=128, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(2): Block(
(norm1): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(q): Linear(in_features=128, out_features=128, bias=True)
(kv): Linear(in_features=128, out_features=256, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=128, out_features=128, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
(sr): Conv2d(128, 128, kernel_size=(4, 4), stride=(4, 4))
(norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
)
(drop_path): DropPath(drop_prob=0.033)
(norm2): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=128, out_features=512, bias=True)
(dwconv): DWConv(
(dwconv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512)
)
(act): GELU(approximate='none')
(fc2): Linear(in_features=512, out_features=128, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(3): Block(
(norm1): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(q): Linear(in_features=128, out_features=128, bias=True)
(kv): Linear(in_features=128, out_features=256, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=128, out_features=128, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
(sr): Conv2d(128, 128, kernel_size=(4, 4), stride=(4, 4))
(norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
)
(drop_path): DropPath(drop_prob=0.040)
(norm2): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=128, out_features=512, bias=True)
(dwconv): DWConv(
(dwconv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512)
)
(act): GELU(approximate='none')
(fc2): Linear(in_features=512, out_features=128, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
)
(norm2): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
(block3): ModuleList(
(0): Block(
(norm1): LayerNorm((320,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(q): Linear(in_features=320, out_features=320, bias=True)
(kv): Linear(in_features=320, out_features=640, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=320, out_features=320, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
(sr): Conv2d(320, 320, kernel_size=(2, 2), stride=(2, 2))
(norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
)
(drop_path): DropPath(drop_prob=0.047)
(norm2): LayerNorm((320,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=320, out_features=1280, bias=True)
(dwconv): DWConv(
(dwconv): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1280)
)
(act): GELU(approximate='none')
(fc2): Linear(in_features=1280, out_features=320, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(1): Block(
(norm1): LayerNorm((320,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(q): Linear(in_features=320, out_features=320, bias=True)
(kv): Linear(in_features=320, out_features=640, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=320, out_features=320, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
(sr): Conv2d(320, 320, kernel_size=(2, 2), stride=(2, 2))
(norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
)
(drop_path): DropPath(drop_prob=0.053)
(norm2): LayerNorm((320,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=320, out_features=1280, bias=True)
(dwconv): DWConv(
(dwconv): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1280)
)
(act): GELU(approximate='none')
(fc2): Linear(in_features=1280, out_features=320, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(2): Block(
(norm1): LayerNorm((320,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(q): Linear(in_features=320, out_features=320, bias=True)
(kv): Linear(in_features=320, out_features=640, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=320, out_features=320, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
(sr): Conv2d(320, 320, kernel_size=(2, 2), stride=(2, 2))
(norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
)
(drop_path): DropPath(drop_prob=0.060)
(norm2): LayerNorm((320,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=320, out_features=1280, bias=True)
(dwconv): DWConv(
(dwconv): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1280)
)
(act): GELU(approximate='none')
(fc2): Linear(in_features=1280, out_features=320, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(3): Block(
(norm1): LayerNorm((320,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(q): Linear(in_features=320, out_features=320, bias=True)
(kv): Linear(in_features=320, out_features=640, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=320, out_features=320, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
(sr): Conv2d(320, 320, kernel_size=(2, 2), stride=(2, 2))
(norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
)
(drop_path): DropPath(drop_prob=0.067)
(norm2): LayerNorm((320,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=320, out_features=1280, bias=True)
(dwconv): DWConv(
(dwconv): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1280)
)
(act): GELU(approximate='none')
(fc2): Linear(in_features=1280, out_features=320, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(4): Block(
(norm1): LayerNorm((320,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(q): Linear(in_features=320, out_features=320, bias=True)
(kv): Linear(in_features=320, out_features=640, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=320, out_features=320, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
(sr): Conv2d(320, 320, kernel_size=(2, 2), stride=(2, 2))
(norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
)
(drop_path): DropPath(drop_prob=0.073)
(norm2): LayerNorm((320,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=320, out_features=1280, bias=True)
(dwconv): DWConv(
(dwconv): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1280)
)
(act): GELU(approximate='none')
(fc2): Linear(in_features=1280, out_features=320, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(5): Block(
(norm1): LayerNorm((320,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(q): Linear(in_features=320, out_features=320, bias=True)
(kv): Linear(in_features=320, out_features=640, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=320, out_features=320, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
(sr): Conv2d(320, 320, kernel_size=(2, 2), stride=(2, 2))
(norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
)
(drop_path): DropPath(drop_prob=0.080)
(norm2): LayerNorm((320,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=320, out_features=1280, bias=True)
(dwconv): DWConv(
(dwconv): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1280)
)
(act): GELU(approximate='none')
(fc2): Linear(in_features=1280, out_features=320, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
)
(norm3): LayerNorm((320,), eps=1e-06, elementwise_affine=True)
(block4): ModuleList(
(0): Block(
(norm1): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(q): Linear(in_features=512, out_features=512, bias=True)
(kv): Linear(in_features=512, out_features=1024, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=512, out_features=512, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): DropPath(drop_prob=0.087)
(norm2): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=512, out_features=2048, bias=True)
(dwconv): DWConv(
(dwconv): Conv2d(2048, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2048)
)
(act): GELU(approximate='none')
(fc2): Linear(in_features=2048, out_features=512, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(1): Block(
(norm1): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(q): Linear(in_features=512, out_features=512, bias=True)
(kv): Linear(in_features=512, out_features=1024, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=512, out_features=512, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): DropPath(drop_prob=0.093)
(norm2): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=512, out_features=2048, bias=True)
(dwconv): DWConv(
(dwconv): Conv2d(2048, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2048)
)
(act): GELU(approximate='none')
(fc2): Linear(in_features=2048, out_features=512, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(2): Block(
(norm1): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(q): Linear(in_features=512, out_features=512, bias=True)
(kv): Linear(in_features=512, out_features=1024, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=512, out_features=512, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): DropPath(drop_prob=0.100)
(norm2): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=512, out_features=2048, bias=True)
(dwconv): DWConv(
(dwconv): Conv2d(2048, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2048)
)
(act): GELU(approximate='none')
(fc2): Linear(in_features=2048, out_features=512, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
)
(norm4): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
)
(decoder): UnetDecoder(
(center): Identity()
(blocks): ModuleList(
(0): DecoderBlock(
(conv1): Conv2dReLU(
(0): Conv2d(832, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(attention1): Attention(
(attention): Identity()
)
(conv2): Conv2dReLU(
(0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(attention2): Attention(
(attention): Identity()
)
)
(1): DecoderBlock(
(conv1): Conv2dReLU(
(0): Conv2d(384, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(attention1): Attention(
(attention): Identity()
)
(conv2): Conv2dReLU(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(attention2): Attention(
(attention): Identity()
)
)
(2): DecoderBlock(
(conv1): Conv2dReLU(
(0): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(attention1): Attention(
(attention): Identity()
)
(conv2): Conv2dReLU(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(attention2): Attention(
(attention): Identity()
)
)
(3): DecoderBlock(
(conv1): Conv2dReLU(
(0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(attention1): Attention(
(attention): Identity()
)
(conv2): Conv2dReLU(
(0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(attention2): Attention(
(attention): Identity()
)
)
(4): DecoderBlock(
(conv1): Conv2dReLU(
(0): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(attention1): Attention(
(attention): Identity()
)
(conv2): Conv2dReLU(
(0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(attention2): Attention(
(attention): Identity()
)
)
)
)
(segmentation_head): SegmentationHead(
(0): Conv2d(16, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): Identity()
(2): Activation(
(activation): Identity()
)
)
)
with torch.no_grad():
    for batch_idx, (inputs, targets) in enumerate(testDL):
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        #loss_v = criterion(outputs, targets)
        accV = acc(outputs, targets)
        f1V = f1(outputs, targets)
        rV = recall(outputs, targets)
        pV = precision(outputs, targets)
        kappaV = kappa(outputs, targets)
        cmV = cm(outputs, targets)

accV = acc.compute()
f1V = f1.compute()
rV = recall.compute()
pV = precision.compute()
kappaV = kappa.compute()
cmV = cm.compute()

acc.reset()
f1.reset()
recall.reset()
precision.reset()
kappa.reset()
cm.reset()
print(accV)
tensor(0.9335, device='cuda:0')
print(f1V)
tensor([0.9452, 0.8386, 0.9251, 0.9367, 0.7420], device='cuda:0')
print(rV)
tensor([0.9586, 0.8680, 0.9087, 0.9161, 0.6983], device='cuda:0')
print(pV)
tensor([0.9323, 0.8110, 0.9421, 0.9582, 0.7915], device='cuda:0')
print(kappaV)
tensor(0.8784, device='cuda:0')
print(cmV)
tensor([[229862911, 723735, 7160869, 812370, 1240456],
[ 462547, 3434455, 12798, 4232, 42522],
[ 12874525, 3325, 130652883, 156093, 91325],
[ 1720055, 369, 318527, 22342779, 6392],
[ 1646586, 72957, 544417, 959, 5242313]],
device='cuda:0')
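Because semantic segmentation results are often reported as intersection-over-union (IoU), the accumulated confusion matrix can also be used to derive a per-class IoU. The short sketch below is optional and not part of the original assessment code.
# Optional: derive per-class IoU from the accumulated confusion matrix
intersection = torch.diag(cmV)
union = cmV.sum(dim=0) + cmV.sum(dim=1) - intersection
print(intersection / union)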
Concluding Remarks
As this module demonstrates, the Segmentation Models library makes it possible to train a semantic segmentation architecture with a SegFormer-based encoder while requiring very minimal changes to the code, which highlights one of the benefits of using the package. Since transformer-based architectures may be able to capture spatial context over broader spatial extents, you may want to experiment with a SegFormer backbone as opposed to one based on a CNN architecture. However, training an architecture that includes transformer components may require more training data in order to see an improvement relative to a CNN-based architecture.