SegFormer Backbone

Implement UNet with SegFormer Backbone

Introduction

The Segmentation Models package provides a variety of semantic segmentation architectures along with CNN backbones that can be used as the encoder. These backbones also have pre-trained weights that can be used to implement transfer learning. Most of the available backbones are traditional CNN architectures, such as ResNet or VGGNet. However, the package also provides an implementation of the transformer-based SegFormer architecture, including the B0 through B5 variants, which range from roughly 3 to 81 million trainable parameters. These SegFormer-based encoders also have pre-trained weights learned from ImageNet.

In this module, we will explore classifying the Landcover.ai dataset using a UNet architecture with a SegFormer backbone. In other words, we will implement UNet with an encoder based on a transformer architecture rather than a CNN. As you will see, this requires very few changes to the code, which is one of the great benefits of the Segmentation Models package.
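As a quick, optional illustration, the short sketch below instantiates UNet with each of the MiT/SegFormer encoder names accepted by Segmentation Models and prints the number of encoder parameters. This is not part of the module workflow; it assumes the segmentation_models_pytorch package is installed and uses encoder_weights=None so that no pre-trained weights are downloaded.

import segmentation_models_pytorch as smp

# Sketch: instantiate UNet with each MiT/SegFormer encoder variant and count
# the encoder parameters. encoder_weights=None skips the ImageNet download.
for name in ["mit_b0", "mit_b1", "mit_b2", "mit_b3", "mit_b4", "mit_b5"]:
    net = smp.Unet(encoder_name=name, encoder_weights=None, classes=5, in_channels=3)
    nParams = sum(p.numel() for p in net.encoder.parameters())
    print(f"{name}: {nParams/1e6:.1f}M encoder parameters")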

Preparation

The preparation steps here are the same as those implemented in the Train UNet module since the same dataset is used: Landcover.ai.

  1. The needed libraries are imported.
  2. The device is defined as the GPU.
  3. The training, validation, and test sets data frames are read in.
  4. A custom dataset subclass is defined based on PyTorch’s Dataset class.
  5. Transformations/augmentations are defined.
  6. DataSets and DataLoaders are instantiated.

Please see the Train UNet module for more details. Again, the same process is implemented here as was implemented in that module.

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt 
import os
import math
import cv2

import torch
import torch.nn as nn
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader

import albumentations as A

import segmentation_models_pytorch as smp

import torchmetrics as tm

import rasterio as rio
from rasterio.plot import show

import torchinfo
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
cuda:0
CLASSES = ['background', 'building', 'woodlands', 'water', 'road']
OUTPUT_DIR = "C:/myFiles/work/dl/lancoverai/"
trainDF = pd.read_csv("C:/myFiles/work/dl/landcover.ai.v1/train.txt", header=None, names=["file"])
trainDF["img"] = OUTPUT_DIR + trainDF['file'] + ".jpg"
trainDF["mask"] = OUTPUT_DIR + trainDF['file'] + "_m.png"
valDF = pd.read_csv("C:/myFiles/work/dl/landcover.ai.v1/val.txt", header=None, names=["file"])
valDF["img"] = OUTPUT_DIR + valDF['file'] + ".jpg"
valDF["mask"] = OUTPUT_DIR + valDF['file'] + "_m.png"
testDF = pd.read_csv("C:/myFiles/work/dl/landcover.ai.v1/test.txt", header=None, names=["file"])
testDF["img"] = OUTPUT_DIR + testDF['file'] + ".jpg"
testDF["mask"] = OUTPUT_DIR + testDF['file'] + "_m.png"
# Subclass and define custom dataset ===========================
class MultiClassSegDataset(Dataset):
    
    def __init__(self, df, transform=None,):
        self.df = df
        self.transform = transform
    
    def __getitem__(self, idx):
        
        image_name = self.df.iloc[idx, 1]
        mask_name = self.df.iloc[idx, 2]
        image = cv2.imread(image_name)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(mask_name, cv2.IMREAD_UNCHANGED)
        image = image.astype('uint8')
        mask = mask[:,:,0]
        if(self.transform is not None):
            transformed = self.transform(image=image, mask=mask)
            image = transformed["image"]
            mask = transformed["mask"]
            image = torch.from_numpy(image)
            mask = torch.from_numpy(mask)   
            image = image.permute(2, 0, 1)
            image = image.float()/255
            mask = mask.long()
        else: 
            image = torch.from_numpy(image)
            mask = torch.from_numpy(mask)
            image = image.permute(2, 0, 1)
            image = image.float()/255
            mask = mask.long()
        return image, mask  
        
    def __len__(self):
        return len(self.df)
test_transform = A.Compose(
    [A.PadIfNeeded(min_height=512, min_width=512, border_mode=4), A.Resize(512, 512),]
)
train_transform = A.Compose(
    [
        A.PadIfNeeded(min_height=512, min_width=512, border_mode=4),
        A.Resize(512, 512),
        A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.MedianBlur(blur_limit=3, always_apply=False, p=0.1),
    ]
)
trainDS = MultiClassSegDataset(trainDF, transform=train_transform)
valDS = MultiClassSegDataset(valDF, transform=test_transform)
print("Number of Training Samples: " + str(len(trainDS)) + " Number of Testing Samples: " + str(len(valDS)))
Number of Training Samples: 7470 Number of Testing Samples: 1602
trainDL = DataLoader(trainDS, batch_size=8, shuffle=True, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=True, timeout=0,
           worker_init_fn=None)
valDL =  DataLoader(valDS, batch_size=8, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=True, timeout=0,
           worker_init_fn=None)
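Before moving on, it can be helpful to pull a single mini-batch from the training DataLoader and confirm that the shapes and data types match what the model will expect. This quick check is not part of the original workflow; it simply reuses the objects defined above.

# Sanity check (added for illustration): inspect one mini-batch.
imgBatch, maskBatch = next(iter(trainDL))
print(imgBatch.shape, imgBatch.dtype)   # expect torch.Size([8, 3, 512, 512]) torch.float32
print(maskBatch.shape, maskBatch.dtype) # expect torch.Size([8, 512, 512]) torch.int64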

Define Model

Defining a UNet model with a SegFormer backbone using Segmentation Models is very similar to defining a model using a CNN-based backbone. We simply need to provide the name of the desired SegFormer architecture as the argument for the encoder_name parameter. In the example, I am using the “mit_b2” version, which has ~24 million parameters. I am also initializing the model using parameters learned from ImageNet and not applying an activation function. The architecture is defined using the same Segmentation Models syntax that was used in the Segmentation Models module. Note that we are differentiating 5 classes, and the input data have three channels.

encoder = "mit_b2"
encoder_weights = "imagenet"
activation = None
model = smp.Unet(
    encoder_name=encoder, 
    encoder_weights=encoder_weights, 
    classes=5, 
    activation=activation,
    in_channels=3
).to(device)
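If you would like to confirm that the SegFormer encoder and UNet decoder connect correctly before training, a simple check is to pass a random tensor through the model and inspect the output shape. This sketch is added for illustration; the batch size of 2 is arbitrary.

# Optional check: a random mini-batch should produce logits with shape
# (batch size, number of classes, height, width).
with torch.no_grad():
    dummy = torch.rand(2, 3, 512, 512).to(device)
    print(model(dummy).shape)  # expect torch.Size([2, 5, 512, 512])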

Train Model

Training the model requires the following preparation steps:

  1. Define and instantiate a loss metric. In this case we will use a multiclass Dice loss.
  2. Define accuracy assessment metrics to monitor during the learning process.
  3. Freeze the encoder. I have chosen not to train the encoder to speed up the learning process. Training the encoder may allow for improved model performance; however, this could also lead to overfitting, especially given the large number of trainable parameters for the SegFormer backbone. If you are curious, feel free to train the model with an unfrozen encoder.
  4. Define an optimizer. We will use AdamW with the default learning rate.
  5. Define the number of training epochs, in this case 25, and the location at which to save the logs and final model.
criterion = smp.losses.DiceLoss(mode="multiclass", classes = 5, from_logits=True, smooth=1e-8)
acc = tm.Accuracy(task="multiclass", average="micro", num_classes=5).to(device)
f1 = tm.F1Score(task="multiclass", average="macro", num_classes=5).to(device)
kappa = tm.CohenKappa(task="multiclass", average = "micro", num_classes=5).to(device)
#https://github.com/qubvel/segmentation_models.pytorch/issues/79 
def freeze_encoder(model):
    for child in model.encoder.children():
        for param in child.parameters():
            param.requires_grad = False
    return

def unfreeze(model):
    for child in model.children():
        for param in child.parameters():
            param.requires_grad = True
    return
freeze_encoder(model)
optimizer = torch.optim.AdamW(model.parameters())
epochs = 25
saveFolder = "C:/myFiles/work/dl/landcoverai_segformer/"

Before training the model, I also print a summary of the model using torchinfo. You can see that the encoder is very different from a CNN architecture since it is made up of the components of SegFormer. Also, you can see that many of the parameters are non-trainable since I chose to freeze the encoder.

torchinfo.summary(model, (8, 3, 512, 512))
===============================================================================================
Layer (type:depth-idx)                        Output Shape              Param #
===============================================================================================
Unet                                          [8, 5, 512, 512]          --
├─MixVisionTransformerEncoder: 1-1            [8, 3, 512, 512]          --
│    └─OverlapPatchEmbed: 2-1                 [8, 16384, 64]            --
│    │    └─Conv2d: 3-1                       [8, 64, 128, 128]         (9,472)
│    │    └─LayerNorm: 3-2                    [8, 16384, 64]            (128)
│    └─ModuleList: 2-2                        --                        --
│    │    └─Block: 3-3                        [8, 16384, 64]            (314,880)
│    │    └─Block: 3-4                        [8, 16384, 64]            (314,880)
│    │    └─Block: 3-5                        [8, 16384, 64]            (314,880)
│    └─LayerNorm: 2-3                         [8, 16384, 64]            (128)
│    └─OverlapPatchEmbed: 2-4                 [8, 4096, 128]            --
│    │    └─Conv2d: 3-6                       [8, 128, 64, 64]          (73,856)
│    │    └─LayerNorm: 3-7                    [8, 4096, 128]            (256)
│    └─ModuleList: 2-5                        --                        --
│    │    └─Block: 3-8                        [8, 4096, 128]            (465,920)
│    │    └─Block: 3-9                        [8, 4096, 128]            (465,920)
│    │    └─Block: 3-10                       [8, 4096, 128]            (465,920)
│    │    └─Block: 3-11                       [8, 4096, 128]            (465,920)
│    └─LayerNorm: 2-6                         [8, 4096, 128]            (256)
│    └─OverlapPatchEmbed: 2-7                 [8, 1024, 320]            --
│    │    └─Conv2d: 3-12                      [8, 320, 32, 32]          (368,960)
│    │    └─LayerNorm: 3-13                   [8, 1024, 320]            (640)
│    └─ModuleList: 2-8                        --                        --
│    │    └─Block: 3-14                       [8, 1024, 320]            (1,656,320)
│    │    └─Block: 3-15                       [8, 1024, 320]            (1,656,320)
│    │    └─Block: 3-16                       [8, 1024, 320]            (1,656,320)
│    │    └─Block: 3-17                       [8, 1024, 320]            (1,656,320)
│    │    └─Block: 3-18                       [8, 1024, 320]            (1,656,320)
│    │    └─Block: 3-19                       [8, 1024, 320]            (1,656,320)
│    └─LayerNorm: 2-9                         [8, 1024, 320]            (640)
│    └─OverlapPatchEmbed: 2-10                [8, 256, 512]             --
│    │    └─Conv2d: 3-20                      [8, 512, 16, 16]          (1,475,072)
│    │    └─LayerNorm: 3-21                   [8, 256, 512]             (1,024)
│    └─ModuleList: 2-11                       --                        --
│    │    └─Block: 3-22                       [8, 256, 512]             (3,172,864)
│    │    └─Block: 3-23                       [8, 256, 512]             (3,172,864)
│    │    └─Block: 3-24                       [8, 256, 512]             (3,172,864)
│    └─LayerNorm: 2-12                        [8, 256, 512]             (1,024)
├─UnetDecoder: 1-2                            [8, 16, 512, 512]         --
│    └─Identity: 2-13                         [8, 512, 16, 16]          --
│    └─ModuleList: 2-14                       --                        --
│    │    └─DecoderBlock: 3-25                [8, 256, 32, 32]          2,507,776
│    │    └─DecoderBlock: 3-26                [8, 128, 64, 64]          590,336
│    │    └─DecoderBlock: 3-27                [8, 64, 128, 128]         147,712
│    │    └─DecoderBlock: 3-28                [8, 32, 256, 256]         27,776
│    │    └─DecoderBlock: 3-29                [8, 16, 512, 512]         6,976
├─SegmentationHead: 1-3                       [8, 5, 512, 512]          --
│    └─Conv2d: 2-15                           [8, 5, 512, 512]          725
│    └─Identity: 2-16                         [8, 5, 512, 512]          --
│    └─Activation: 2-17                       [8, 5, 512, 512]          --
│    │    └─Identity: 3-30                    [8, 5, 512, 512]          --
===============================================================================================
Total params: 27,477,589
Trainable params: 3,281,301
Non-trainable params: 24,196,288
Total mult-adds (G): 110.78
===============================================================================================
Input size (MB): 25.17
Forward/backward pass size (MB): 9101.64
Params size (MB): 109.91
Estimated Total Size (MB): 9236.72
===============================================================================================

We are now ready to train the model using our standard training loop. A few notes:

  1. I am training for a total of 25 epochs with a mini-batch size of 8 chips.
  2. The validation set will be predicted at the end of each training epoch in order to assess the model and monitor for overfitting.
  3. I am using the Dice loss and monitoring overall accuracy, class-aggregated macro average F1-score, and the Kappa statistic.
  4. At the end of each training epoch, the model will only be saved to disk if there is an improvement in the F1-score for the validation data.

If you execute the training loop, note that it will take several hours to complete. I have provided a trained model file if you would like to run the following model assessment code without training the model.

eNum = []
t_loss = []
t_acc = []
t_f1 = []
t_kappa = []
v_loss = []
v_acc = []
v_f1 = []
v_kappa = []

f1VMax = 0.0

# Loop over epochs
for epoch in range(1, epochs+1):
    model.train()
    running_loss = 0.0
    # Loop over training batches
    for batch_idx, (inputs, targets) in enumerate(trainDL):
        # Get data and move to device
        inputs, targets = inputs.to(device), targets.to(device)

        # Clear gradients
        optimizer.zero_grad()
        # Predict data
        outputs = model(inputs)
        # Calculate loss
        loss = criterion(outputs, targets)

        # Calculate metrics
        accT = acc(outputs, targets)
        f1T = f1(outputs, targets)
        kappaT = kappa(outputs, targets)
        
        # Backpropagate
        loss.backward()

        # Update parameters
        optimizer.step()

        #Update running with batch results
        running_loss += loss.item()

    # Accumulate loss and metrics at end of training epoch
    epoch_loss = running_loss/len(trainDL)
    accT = acc.compute()
    f1T = f1.compute()
    kappaT = kappa.compute()

    # Print Losses and metrics at end of each training epoch   
    print(f'Epoch: {epoch}, Training Loss: {epoch_loss:.4f}, Training Accuracy: {accT:.4f}, Training F1: {f1T:.4f}, Training Kappa: {kappaT:.4f}')

    # Append results
    eNum.append(epoch)
    t_loss.append(epoch_loss)
    t_acc.append(accT.detach().cpu().numpy())
    t_f1.append(f1T.detach().cpu().numpy())
    t_kappa.append(kappaT.detach().cpu().numpy())

    # Reset metrics
    acc.reset()
    f1.reset()
    kappa.reset()

    # Switch to evaluation mode so dropout/drop path and batch norm layers
    # behave deterministically, then loop over validation batches
    model.eval()
    with torch.no_grad():
        #Initialize running validation loss
        running_loss_v = 0.0
        for batch_idx, (inputs, targets) in enumerate(valDL):
            # Get data and move to device
            inputs, targets = inputs.to(device), targets.to(device)

            # Predict data
            outputs = model(inputs)
            # Calculate validation loss
            loss_v = criterion(outputs, targets)
            
            #Update running with batch results
            running_loss_v += loss_v.item()

            # Calculate metrics
            accV = acc(outputs, targets)
            f1V = f1(outputs, targets)
            kappaV = kappa(outputs, targets)
            
    #Accumulate loss and metrics at end of validation epoch
    epoch_loss_v = running_loss_v/len(valDL)
    accV = acc.compute()
    f1V = f1.compute()
    kappaV = kappa.compute()

    # Print validation loss and metrics
    print(f'Validation Loss: {epoch_loss_v:.4f}, Validation Accuracy: {accV:.4f}, Validation F1: {f1V:.4f}, Validation Kappa: {kappaV:.4f}')

    # Append results
    v_loss.append(epoch_loss_v)
    v_acc.append(accV.detach().cpu().numpy())
    v_f1.append(f1V.detach().cpu().numpy())
    v_kappa.append(kappaV.detach().cpu().numpy())

    # Reset metrics
    acc.reset()
    f1.reset()
    kappa.reset()

    # Save model if validation F1-score improves
    f1V2 = f1V.detach().cpu().numpy()
    if f1V2 > f1VMax:
        f1VMax = f1V2
        torch.save(model.state_dict(), saveFolder + 'lcai_SegFormer_model.pt')
        print(f'Model saved for epoch {epoch}.')

SeNum = pd.Series(eNum, name="epoch")
St_loss = pd.Series(t_loss, name="training_loss")
St_acc = pd.Series(t_acc, name="training_accuracy")
St_f1 = pd.Series(t_f1, name="training_f1")
St_kappa = pd.Series(t_kappa, name="training_kappa")
Sv_loss = pd.Series(v_loss, name="val_loss")
Sv_acc = pd.Series(v_acc, name="val_accuracy")
Sv_f1 = pd.Series(v_f1, name="val_f1")
Sv_kappa = pd.Series(v_kappa, name="val_kappa")
resultsDF = pd.concat([SeNum, St_loss, St_acc, St_f1, St_kappa, Sv_loss, Sv_acc, Sv_f1, Sv_kappa], axis=1)

resultsDF.to_csv(saveFolder+"resultsLCAISegFormer.csv")

Model Assessment

The remainder of the code is used to assess the trained model using the withheld testing data. This includes the following:

  1. Read in the training logs and plot the training and validation losses and F1-scores.
  2. Re-instantiate the model and load in the saved parameters.
  3. Define the test DataSet and DataLoader.
  4. Instantiate new assessment metrics with no aggregation so that metrics are obtained separately for each class.
  5. Execute a validation loop and print the obtained assessment metrics.
resultsDF = pd.read_csv("data/models/resultsLCAISegFormer.csv")
plt.rcParams['figure.figsize'] = [10, 10]
firstPlot = resultsDF.plot(x='epoch', y="training_loss")
resultsDF.plot(x='epoch', y="val_loss", ax=firstPlot)
plt.show()

plt.rcParams['figure.figsize'] = [10, 10]
firstPlot = resultsDF.plot(x='epoch', y="training_f1")
resultsDF.plot(x='epoch', y="val_f1", ax=firstPlot)
plt.show()

model = smp.Unet(
    encoder_name=encoder, 
    encoder_weights=encoder_weights, 
    classes=5, 
    activation=activation,
    in_channels=3
).to(device)
model.load_state_dict(torch.load(saveFolder + "lcai_SegFormer_model.pt"))
<All keys matched successfully>
testDS = MultiClassSegDataset(testDF, transform=test_transform)
testDL =  DataLoader(testDS, batch_size=8, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=True, timeout=0,
           worker_init_fn=None)
acc = tm.Accuracy(task="multiclass", num_classes=5).to(device)
f1 = tm.F1Score(task="multiclass", num_classes=5, average='none').to(device)
recall = tm.Recall(task="multiclass", num_classes=5, average='none').to(device)
precision = tm.Precision(task="multiclass", num_classes=5, average='none').to(device)
kappa = tm.CohenKappa(task="multiclass", num_classes=5).to(device)
cm = tm.ConfusionMatrix(task="multiclass", num_classes=5).to(device)
model.eval()
Unet(
  (encoder): MixVisionTransformerEncoder(
    (patch_embed1): OverlapPatchEmbed(
      (proj): Conv2d(3, 64, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
      (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    )
    (patch_embed2): OverlapPatchEmbed(
      (proj): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (patch_embed3): OverlapPatchEmbed(
      (proj): Conv2d(128, 320, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
    )
    (patch_embed4): OverlapPatchEmbed(
      (proj): Conv2d(320, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
    (block1): ModuleList(
      (0): Block(
        (norm1): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (q): Linear(in_features=64, out_features=64, bias=True)
          (kv): Linear(in_features=64, out_features=128, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=64, out_features=64, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (sr): Conv2d(64, 64, kernel_size=(8, 8), stride=(8, 8))
          (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=64, out_features=256, bias=True)
          (dwconv): DWConv(
            (dwconv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256)
          )
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=256, out_features=64, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (1): Block(
        (norm1): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (q): Linear(in_features=64, out_features=64, bias=True)
          (kv): Linear(in_features=64, out_features=128, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=64, out_features=64, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (sr): Conv2d(64, 64, kernel_size=(8, 8), stride=(8, 8))
          (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        )
        (drop_path): DropPath(drop_prob=0.007)
        (norm2): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=64, out_features=256, bias=True)
          (dwconv): DWConv(
            (dwconv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256)
          )
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=256, out_features=64, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (2): Block(
        (norm1): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (q): Linear(in_features=64, out_features=64, bias=True)
          (kv): Linear(in_features=64, out_features=128, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=64, out_features=64, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (sr): Conv2d(64, 64, kernel_size=(8, 8), stride=(8, 8))
          (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        )
        (drop_path): DropPath(drop_prob=0.013)
        (norm2): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=64, out_features=256, bias=True)
          (dwconv): DWConv(
            (dwconv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256)
          )
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=256, out_features=64, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (norm1): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
    (block2): ModuleList(
      (0): Block(
        (norm1): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (q): Linear(in_features=128, out_features=128, bias=True)
          (kv): Linear(in_features=128, out_features=256, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=128, out_features=128, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (sr): Conv2d(128, 128, kernel_size=(4, 4), stride=(4, 4))
          (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        )
        (drop_path): DropPath(drop_prob=0.020)
        (norm2): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=128, out_features=512, bias=True)
          (dwconv): DWConv(
            (dwconv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512)
          )
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=512, out_features=128, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (1): Block(
        (norm1): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (q): Linear(in_features=128, out_features=128, bias=True)
          (kv): Linear(in_features=128, out_features=256, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=128, out_features=128, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (sr): Conv2d(128, 128, kernel_size=(4, 4), stride=(4, 4))
          (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        )
        (drop_path): DropPath(drop_prob=0.027)
        (norm2): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=128, out_features=512, bias=True)
          (dwconv): DWConv(
            (dwconv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512)
          )
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=512, out_features=128, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (2): Block(
        (norm1): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (q): Linear(in_features=128, out_features=128, bias=True)
          (kv): Linear(in_features=128, out_features=256, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=128, out_features=128, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (sr): Conv2d(128, 128, kernel_size=(4, 4), stride=(4, 4))
          (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        )
        (drop_path): DropPath(drop_prob=0.033)
        (norm2): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=128, out_features=512, bias=True)
          (dwconv): DWConv(
            (dwconv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512)
          )
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=512, out_features=128, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (3): Block(
        (norm1): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (q): Linear(in_features=128, out_features=128, bias=True)
          (kv): Linear(in_features=128, out_features=256, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=128, out_features=128, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (sr): Conv2d(128, 128, kernel_size=(4, 4), stride=(4, 4))
          (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        )
        (drop_path): DropPath(drop_prob=0.040)
        (norm2): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=128, out_features=512, bias=True)
          (dwconv): DWConv(
            (dwconv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512)
          )
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=512, out_features=128, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (norm2): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
    (block3): ModuleList(
      (0): Block(
        (norm1): LayerNorm((320,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (q): Linear(in_features=320, out_features=320, bias=True)
          (kv): Linear(in_features=320, out_features=640, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=320, out_features=320, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (sr): Conv2d(320, 320, kernel_size=(2, 2), stride=(2, 2))
          (norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        )
        (drop_path): DropPath(drop_prob=0.047)
        (norm2): LayerNorm((320,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=320, out_features=1280, bias=True)
          (dwconv): DWConv(
            (dwconv): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1280)
          )
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=1280, out_features=320, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (1): Block(
        (norm1): LayerNorm((320,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (q): Linear(in_features=320, out_features=320, bias=True)
          (kv): Linear(in_features=320, out_features=640, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=320, out_features=320, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (sr): Conv2d(320, 320, kernel_size=(2, 2), stride=(2, 2))
          (norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        )
        (drop_path): DropPath(drop_prob=0.053)
        (norm2): LayerNorm((320,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=320, out_features=1280, bias=True)
          (dwconv): DWConv(
            (dwconv): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1280)
          )
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=1280, out_features=320, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (2): Block(
        (norm1): LayerNorm((320,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (q): Linear(in_features=320, out_features=320, bias=True)
          (kv): Linear(in_features=320, out_features=640, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=320, out_features=320, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (sr): Conv2d(320, 320, kernel_size=(2, 2), stride=(2, 2))
          (norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        )
        (drop_path): DropPath(drop_prob=0.060)
        (norm2): LayerNorm((320,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=320, out_features=1280, bias=True)
          (dwconv): DWConv(
            (dwconv): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1280)
          )
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=1280, out_features=320, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (3): Block(
        (norm1): LayerNorm((320,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (q): Linear(in_features=320, out_features=320, bias=True)
          (kv): Linear(in_features=320, out_features=640, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=320, out_features=320, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (sr): Conv2d(320, 320, kernel_size=(2, 2), stride=(2, 2))
          (norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        )
        (drop_path): DropPath(drop_prob=0.067)
        (norm2): LayerNorm((320,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=320, out_features=1280, bias=True)
          (dwconv): DWConv(
            (dwconv): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1280)
          )
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=1280, out_features=320, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (4): Block(
        (norm1): LayerNorm((320,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (q): Linear(in_features=320, out_features=320, bias=True)
          (kv): Linear(in_features=320, out_features=640, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=320, out_features=320, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (sr): Conv2d(320, 320, kernel_size=(2, 2), stride=(2, 2))
          (norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        )
        (drop_path): DropPath(drop_prob=0.073)
        (norm2): LayerNorm((320,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=320, out_features=1280, bias=True)
          (dwconv): DWConv(
            (dwconv): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1280)
          )
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=1280, out_features=320, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (5): Block(
        (norm1): LayerNorm((320,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (q): Linear(in_features=320, out_features=320, bias=True)
          (kv): Linear(in_features=320, out_features=640, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=320, out_features=320, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (sr): Conv2d(320, 320, kernel_size=(2, 2), stride=(2, 2))
          (norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        )
        (drop_path): DropPath(drop_prob=0.080)
        (norm2): LayerNorm((320,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=320, out_features=1280, bias=True)
          (dwconv): DWConv(
            (dwconv): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1280)
          )
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=1280, out_features=320, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (norm3): LayerNorm((320,), eps=1e-06, elementwise_affine=True)
    (block4): ModuleList(
      (0): Block(
        (norm1): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (q): Linear(in_features=512, out_features=512, bias=True)
          (kv): Linear(in_features=512, out_features=1024, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=512, out_features=512, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): DropPath(drop_prob=0.087)
        (norm2): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (dwconv): DWConv(
            (dwconv): Conv2d(2048, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2048)
          )
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (1): Block(
        (norm1): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (q): Linear(in_features=512, out_features=512, bias=True)
          (kv): Linear(in_features=512, out_features=1024, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=512, out_features=512, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): DropPath(drop_prob=0.093)
        (norm2): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (dwconv): DWConv(
            (dwconv): Conv2d(2048, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2048)
          )
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (2): Block(
        (norm1): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (q): Linear(in_features=512, out_features=512, bias=True)
          (kv): Linear(in_features=512, out_features=1024, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=512, out_features=512, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): DropPath(drop_prob=0.100)
        (norm2): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (dwconv): DWConv(
            (dwconv): Conv2d(2048, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2048)
          )
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (norm4): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
  )
  (decoder): UnetDecoder(
    (center): Identity()
    (blocks): ModuleList(
      (0): DecoderBlock(
        (conv1): Conv2dReLU(
          (0): Conv2d(832, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (attention1): Attention(
          (attention): Identity()
        )
        (conv2): Conv2dReLU(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (attention2): Attention(
          (attention): Identity()
        )
      )
      (1): DecoderBlock(
        (conv1): Conv2dReLU(
          (0): Conv2d(384, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (attention1): Attention(
          (attention): Identity()
        )
        (conv2): Conv2dReLU(
          (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (attention2): Attention(
          (attention): Identity()
        )
      )
      (2): DecoderBlock(
        (conv1): Conv2dReLU(
          (0): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (attention1): Attention(
          (attention): Identity()
        )
        (conv2): Conv2dReLU(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (attention2): Attention(
          (attention): Identity()
        )
      )
      (3): DecoderBlock(
        (conv1): Conv2dReLU(
          (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (attention1): Attention(
          (attention): Identity()
        )
        (conv2): Conv2dReLU(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (attention2): Attention(
          (attention): Identity()
        )
      )
      (4): DecoderBlock(
        (conv1): Conv2dReLU(
          (0): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (attention1): Attention(
          (attention): Identity()
        )
        (conv2): Conv2dReLU(
          (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (attention2): Attention(
          (attention): Identity()
        )
      )
    )
  )
  (segmentation_head): SegmentationHead(
    (0): Conv2d(16, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Identity()
    (2): Activation(
      (activation): Identity()
    )
  )
)
with torch.no_grad():
    for batch_idx, (inputs, targets) in enumerate(testDL):
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        #loss_v = criterion(outputs, targets)
        accV = acc(outputs, targets)
        f1V = f1(outputs, targets)
        rV = recall(outputs, targets)
        pV = precision(outputs, targets)
        kappaV = kappa(outputs, targets)
        cmV = cm(outputs, targets)
accV = acc.compute()
f1V = f1.compute()
rV = recall.compute()
pV = precision.compute()
kappaV = kappa.compute()
cmV = cm.compute()
acc.reset()
f1.reset()
recall.reset()
precision.reset()
kappa.reset()
cm.reset()
print(accV)
tensor(0.9335, device='cuda:0')
print(f1V)
tensor([0.9452, 0.8386, 0.9251, 0.9367, 0.7420], device='cuda:0')
print(rV)
tensor([0.9586, 0.8680, 0.9087, 0.9161, 0.6983], device='cuda:0')
print(pV)
tensor([0.9323, 0.8110, 0.9421, 0.9582, 0.7915], device='cuda:0')
print(kappaV)
tensor(0.8784, device='cuda:0')
print(cmV)
tensor([[229862911,    723735,   7160869,    812370,   1240456],
        [   462547,   3434455,     12798,      4232,     42522],
        [ 12874525,      3325, 130652883,    156093,     91325],
        [  1720055,       369,    318527,  22342779,      6392],
        [  1646586,     72957,    544417,       959,   5242313]],
       device='cuda:0')
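As an aside, the class-level recall (producer's accuracy) and precision (user's accuracy) reported above can be recovered directly from the confusion matrix. The short sketch below is added for illustration and relies on the torchmetrics convention that rows correspond to the reference classes and columns to the predictions.

# Sketch (added for illustration): derive per-class recall and precision from
# the confusion matrix. Rows are reference classes; columns are predictions.
cmF = cmV.float()
recallFromCM = torch.diag(cmF) / cmF.sum(dim=1)
precisionFromCM = torch.diag(cmF) / cmF.sum(dim=0)
for name, r, p in zip(CLASSES, recallFromCM, precisionFromCM):
    print(f"{name}: recall = {r:.4f}, precision = {p:.4f}")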

Concluding Remarks

As this module demonstrates, the Segmentation Models library makes it possible to train a semantic segmentation architecture with a SegFormer-based encoder while making only minimal changes to the code, which highlights one of the benefits of the package. Since transformer-based architectures may be able to capture spatial context over broader spatial extents, you may want to experiment with a SegFormer backbone as opposed to one based on a CNN architecture. However, training an architecture that includes transformer components may require more training data in order to see an improvement relative to a CNN-based architecture.
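For example, swapping in a CNN backbone for comparison requires changing only the encoder_name argument; the sketch below uses "resnet34" as an arbitrary choice and would otherwise reuse the same training and assessment code from this module.

# Hedged sketch: compare against a CNN encoder by changing only encoder_name;
# "resnet34" is an arbitrary example backbone.
modelCNN = smp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    classes=5,
    activation=None,
    in_channels=3
).to(device)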