Improving Model Performance
Introduction
Before moving on to discuss convolutional neural networks, I wanted to take some time to discuss methods that are commonly used to try to improve model performance. In its current state, deep learning is science, art, and tradecraft. As a result, analysts learn through time and experience which techniques may help them build more accurate and generalizable models. Unfortunately, due to the complexity of these models and the methods used to train them, it is not feasible to test every option that might improve a model. Instead, individuals must build intuition as to which methods might be most appropriate for their specific use case or application domain. With that said, there are some commonly used techniques that are often explored, and those are the focus of this section. Note that I am saving a discussion of some specific methods for later modules.
Methods that are often explored to potentially improve model performance include the following.
- Collecting more or better training data. Deep learning models tend to be data hungry, partially due to the large number of parameters that must be estimated and the risk of overfitting. As a result, adding training samples is often the first suggestion for improving model performance. Unfortunately, it is not always practical or possible to collect more data. Further, the number of samples needed to successfully train a model without overfitting is rarely well defined and tends to be case specific.
- It is possible to initialize models using weights/parameters learned from other datasets. For example, large datasets, such as ImageNet and Common Objects in Context (COCO), can be used to train models, and the learned parameters can then serve as a starting point for applying the model to a new problem or dataset. One option is to initialize the model using these learned parameters but still train the entire architecture, essentially using them as a starting point as opposed to a random initialization. When learned parameters are used to initialize the model but all components are still trainable, it has been documented that using different learning rates for different components of the architecture can be useful. We will discuss this in a later module. Another option is to train only part of the architecture, leaving the rest frozen or untrainable. The assumption is that the learned parameters are adequate for the current task or that they capture information that is meaningful for the new task to which they are being applied. The process of using parameters that were learned from another, and often much larger, dataset is called transfer learning. We will discuss transfer learning in later modules, so I will not cover it further in this section.
- In order to reduce overfitting and improve model generalization, it is common to augment the input data. Using images as an example, random flips, rotations, sharpening, blurring, and/or adjustments to saturation, contrast, and/or brightness can be applied. This is termed image augmentation, and the idea is that creating more variability in the training data can result in a more varied representation and thus discourage overfitting to the training data. We will explore data augmentations in the context of convolutional neural networks for scene labeling and semantic segmentation tasks in later modules.
- Data imbalance can have a negative impact on the learning process. Classes that make up a larger proportion of the training dataset will have a larger impact on the calculated loss and, thus, on the parameter updates; as a result, the model will tend to focus more on these classes than on less abundant classes. Practically, this can result in the model doing a poorer job of predicting the less abundant classes, overpredicting the abundance of the more prominent classes, and/or underpredicting the abundance of the less prominent classes. There are several methods to combat class imbalance, including using class weights in the loss calculation, using loss functions that are less impacted by class imbalance, and/or resampling or augmenting the training data so that classes are represented more evenly. We will discuss these methods in this section.
- The model architecture can also be altered to potentially improve model performance. Common modifications include integrating batch normalization and/or dropouts, increasing the complexity of the network so that it can model more complex patterns in the data, and using different activation functions. We will discuss some of these methods in this section. One issue here is that increasing the complexity of the network can result in overfitting and reduced generalization, especially if the training dataset is small, so increasing the size and complexity of the architecture may not have the desired effect. Many convolutional neural network architectures have been specifically designed to capture complex patterns without overfitting to the training dataset. We will discuss some of these architectures and the logic behind them in later sections.
- The optimization algorithm can have a large impact on the modeling process, and there are many alternatives to traditional mini-batch gradient descent that are derived from and build on this method. I generally use the AdamW optimizer as a default. However, analysts have varying opinions about the best optimizer for a given use case, and each option has pros and cons.
- Regardless of the optimizer used, the selected learning rate can have a large impact on model performance. My go-to first attempts to improve model performance are (1) to change or augment the loss function and/or (2) to adjust the learning rate. Other optimizer-specific settings can impact the learning process as well. One way to select the learning rate is to use a learning rate finder, which we will discuss in this section.
- The learning rate, or other optimizer settings, can also be adjusted during the training process. In PyTorch, this is generally implemented using learning rate schedulers. We will explore schedulers in this section.
- There are other training loop modifications that may also be helpful. In this section, we will explore gradient accumulation.
Preparation
As usual, I begin with imports. I also set the device variable. Here, I use a GPU, if one is available, to potentially speed up computation.
import numpy as np
import pandas as pd
import matplotlib
from matplotlib import pyplot as plt
import seaborn as sns
plt.rcParams['figure.figsize'] = [20, 20]
import os
import random
from sklearn.utils.class_weight import compute_class_weight
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import torchmetrics as tm
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
cuda:0
I will use the same data and problem that we have been exploring: the differentiation of land cover types using band means derived from the EuroSatAllBands dataset. There is nothing new in the code blocks below. Here is a quick review of the data preparation process.
- Read in the data tables of band means
- Calculate the means and standard deviations from the training dataset
- Define a Dataset subclass
- Instantiate the datasets and normalize the band means
- Instantiate the DataLoaders
folder = "C:/myFiles/work/dl/eurosat/EuroSATallBands/"
trainAgg = pd.read_csv(folder + "train_aggregated.csv")
testAgg = pd.read_csv(folder + "test_aggregated.csv")
valAgg = pd.read_csv(folder + "val_aggregated.csv")
trainAgg.head()
Unnamed: 0 class code ... NIR_Narrow swir1 swir2
0 0 AnnualCrop 0 ... 674.346680 1461.915527 3229.570801
1 1 AnnualCrop 0 ... 758.629395 1189.456543 2714.708008
2 2 AnnualCrop 0 ... 761.717041 1621.122070 3900.247314
3 3 AnnualCrop 0 ... 638.532471 1469.149170 3986.078125
4 4 AnnualCrop 0 ... 948.018066 1656.648682 3609.164551
[5 rows x 13 columns]
trainAggMns = np.array(trainAgg.iloc[:,3:].mean(axis=0)).flatten()
print(trainAggMns)
[1116.07645724 1035.31213955 937.87729602 1184.29159721 1971.09456594
2334.38145703 2262.04021603 724.95266414 1103.83350066 2552.01787743]
trainAggSDs = np.array(trainAgg.iloc[:,3:].std(axis=0)).flatten()
print(trainAggSDs)
[ 255.54328897 312.29876238 479.85770488 500.93361557 779.66575883
976.84613356 970.08515215 384.49257895 670.88688372 1120.89990655]
class EuroSat(Dataset):
    def __init__(self, df, bndMns, bndSDs):
        super().__init__()
        self.df = df
        self.bndMns = bndMns
        self.bndSDs = bndSDs

    def __getitem__(self, idx):
        #Band means are in columns 3 onward; the numeric class code is in column 2
        bands = [self.df.iloc[idx, 3:]]
        label = [self.df.iloc[idx, 2]]
        #Normalize using the training set band means and standard deviations
        bands = np.array(bands)
        bands = (bands-self.bndMns)/self.bndSDs
        label = np.array(label)
        bands = bands.astype('float32')
        bands = torch.from_numpy(bands).squeeze().float()
        label = torch.from_numpy(label).squeeze().float()
        label = label.long()
        return bands, label

    def __len__(self):
        return len(self.df)
trainDS = EuroSat(trainAgg, trainAggMns, trainAggSDs)
testDS = EuroSat(testAgg, trainAggMns, trainAggSDs)
valDS = EuroSat(valAgg, trainAggMns, trainAggSDs)
trainDL = torch.utils.data.DataLoader(trainDS, batch_size=256, shuffle=True, sampler=None,
                                      num_workers=0, pin_memory=False, drop_last=False)
testDL = torch.utils.data.DataLoader(testDS, batch_size=256, shuffle=True, sampler=None,
                                     num_workers=0, pin_memory=False, drop_last=False)
valDL = torch.utils.data.DataLoader(valDS, batch_size=256, shuffle=True, sampler=None,
                                    num_workers=0, pin_memory=False, drop_last=False)
Dropouts
Dropouts are a regularization method in which randomly selected neurons are dropped out, or set to zero, during each forward pass over a training mini-batch. Since a dropped neuron does not contribute to the output, its associated weights/parameters receive no gradient from that mini-batch. The idea is that not updating all of the weights/parameters during each weight update can reduce overfitting and improve generalization. The probability that each neuron is dropped is controlled by the p parameter. In the example model architecture below, I am applying dropouts between the fully connected layers with p=0.3, which gives each neuron a 30% chance of being dropped in each weight/parameter update. Note that I do not use dropouts after the final fully connected layer. Dropouts are implemented in PyTorch using nn.Dropout(), and they are only active while the model is in training mode; calling model.eval() disables them.
There is currently some debate as to whether dropouts are still necessary. With the advent and wide adoption of batch normalization, dropouts have been used less frequently. Many analysts now rely on batch normalization as opposed to dropouts to reduce overfitting and improve generalization, and including both may be unnecessary. Again, analysts have varying opinions. I generally prefer to use batch normalization as opposed to dropouts. However, you may want to experiment with dropouts as a potential means to improve model performance. Note that dropout methods are also available for convolutional layers (e.g., nn.Dropout2d()).
class myFCNDrop(nn.Module):
    def __init__(self, inSize, hiddenSizes, outSize):
        super().__init__()
        self.inSize = inSize
        self.hiddenSize = hiddenSizes
        self.outSize = outSize
        self.theNetwork = nn.Sequential(
            nn.Linear(inSize, hiddenSizes[0]),
            nn.Dropout(p=0.3),
            nn.ReLU(inplace=True),
            nn.Linear(hiddenSizes[0], hiddenSizes[1]),
            nn.Dropout(p=.3),
            nn.ReLU(inplace=True),
            nn.Linear(hiddenSizes[1], outSize)
        )

    def forward(self, x):
        x = self.theNetwork(x)
        return x
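As a quick check of how nn.Dropout() behaves, the sketch below applies it to a toy tensor of ones (not the EuroSAT data). In training mode, roughly 30% of the values are zeroed and the surviving values are scaled by 1/(1 - p) so that the expected activation is unchanged; in evaluation mode, the layer returns its input untouched. This is why switching between model.train() and model.eval() matters when dropouts are included.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
p = 0.3
dropLayer = nn.Dropout(p=p)
ones = torch.ones(10000)  # toy input

# Training mode: each value is zeroed with probability p
dropLayer.train()
trainOut = dropLayer(ones)
fracDropped = (trainOut == 0).float().mean().item()
# Surviving values are rescaled by 1 / (1 - p)
survivors = trainOut[trainOut != 0]

# Evaluation mode: dropout is disabled and the input passes through unchanged
dropLayer.eval()
evalOut = dropLayer(ones)

print(f"Fraction dropped in training mode: {fracDropped:.3f}")
print(f"Value of surviving entries: {survivors[0].item():.4f}")
print(f"Unchanged in eval mode: {torch.equal(evalOut, ones)}")
```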
Leaky ReLU
One known problem with the rectified linear unit (ReLU) activation function is “dying ReLU”. Remember that this activation function converts negative activations to 0 and maintains all positive activation values. If the inputs to a neuron are consistently negative, its output will always be 0 and the gradient through it will also be 0, so the weights/parameters feeding into it will no longer be updated. To alleviate this issue, the leaky ReLU was introduced. Instead of simply converting all negative activations to 0, negative values are maintained but with a reduced magnitude by multiplying them by a small positive slope term that is less than 1.
In PyTorch, ReLU can be replaced with leaky ReLU by simply changing nn.ReLU() to nn.LeakyReLU(). You can also set the slope applied to negative values using the negative_slope parameter, which is 0.01 by default.
I have generally found using leaky ReLU to be useful and worth considering. It is a very simple change, and it does not add trainable parameters to the model. There are also versions of ReLU that introduce trainable parameters; however, we will not discuss those here.
class myFCNLeaky(nn.Module):
    def __init__(self, inSize, hiddenSizes, outSize):
        super().__init__()
        self.inSize = inSize
        self.hiddenSize = hiddenSizes
        self.outSize = outSize
        self.theNetwork = nn.Sequential(
            nn.Linear(inSize, hiddenSizes[0]),
            nn.BatchNorm1d(hiddenSizes[0]),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Linear(hiddenSizes[0], hiddenSizes[1]),
            nn.BatchNorm1d(hiddenSizes[1]),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Linear(hiddenSizes[1], outSize)
        )

    def forward(self, x):
        x = self.theNetwork(x)
        return x
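To see the gradient difference that motivates leaky ReLU, the sketch below passes a small toy tensor through nn.ReLU() and nn.LeakyReLU(negative_slope=0.1) and backpropagates through each. For negative inputs, ReLU yields a gradient of 0, whereas leaky ReLU yields a gradient equal to the negative slope, so the upstream weights can still receive updates.

```python
import torch
import torch.nn as nn

x = torch.tensor([-2.0, -0.5, 1.5, 3.0], requires_grad=True)  # toy inputs

# Standard ReLU: negative inputs produce an output of 0 and a gradient of 0
reluOut = nn.ReLU()(x)
reluOut.sum().backward()
reluGrad = x.grad.clone()

# Leaky ReLU: negative inputs are multiplied by negative_slope, so their gradient is 0.1
x.grad.zero_()
leakyOut = nn.LeakyReLU(negative_slope=0.1)(x)
leakyOut.sum().backward()
leakyGrad = x.grad.clone()

print("ReLU output:", reluOut.detach(), "gradient:", reluGrad)
print("Leaky ReLU output:", leakyOut.detach(), "gradient:", leakyGrad)
```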
Class Weighting
One means to potentially reduce the impact of class imbalance is to apply weights to each class when calculating the loss. Specifically, less abundant classes should receive a higher weight in the loss calculation than more abundant classes. To accomplish this, it is common to calculate each class's weight as the inverse of its relative abundance in the dataset.
Fortunately, scikit-learn provides the compute_class_weight() function for accomplishing this. It requires first extracting the class labels into a NumPy array. Setting the class_weight parameter to ‘balanced’ will yield weights that equalize the impact of each class on the weight updates. Once the weights are defined, they can be converted to a torch tensor and moved to the device.
Some loss functions allow for the incorporation of class weights. Here, I have provided an example of defining a cross entropy loss with class weights using the weight parameter of nn.CrossEntropyLoss(). To see whether and how weights can be used with other loss functions, please consult the associated documentation.
It is also possible to incorporate class weights into some assessment metric calculations. Although this won’t inform the parameter updates, it can be useful for selecting a model or determining whether the model is improving.
#Column 2 ("code") holds the numeric class labels; columns 3 onward hold band means
npLbls = np.array(trainAgg.iloc[:,2])
npLbls
weights = compute_class_weight(class_weight='balanced', classes=np.unique(npLbls), y=npLbls)
weightsT = torch.tensor(weights, dtype=torch.float).to(device)
weightsT
criterion = nn.CrossEntropyLoss(weight=weightsT).to(device)
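For reference, scikit-learn documents the ‘balanced’ weights as n_samples / (n_classes * np.bincount(y)). The sketch below reproduces that formula with NumPy on a small, hypothetical label vector (not the EuroSAT data) and confirms that, after weighting, every class contributes the same total weight.

```python
import numpy as np

# Hypothetical imbalanced labels: six samples of class 0, three of class 1, one of class 2
labels = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 2])
classes = np.unique(labels)
counts = np.bincount(labels)

# 'balanced' weights: n_samples / (n_classes * number of samples in each class)
balancedWeights = len(labels) / (len(classes) * counts)

# Total weight per class (weight x count) is the same for every class
totalPerClass = balancedWeights * counts
print(balancedWeights)
print(totalPerClass)
```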
Alternative Loss Metric
Another option is to use a loss function that is less impacted by class imbalance, such as the Dice or Tversky loss. Further, the focal Tversky loss is generally more robust to data imbalance than cross entropy loss while also allowing you to define the relative impacts of false positive and false negative errors and to adjust the impact of difficult-to-classify training samples. It is also possible to combine loss functions to create a combination loss, such as cross entropy loss + Dice loss or cross entropy loss + Tversky loss.
Generally, I have found using an alternative loss function to be worth considering and beneficial, especially when the data are imbalanced. I have often seen notable improvements when using Dice as opposed to cross entropy, or Dice + cross entropy as opposed to just cross entropy. The Tversky loss is useful when you want more control over the relative impact of false positive and false negative errors.
Below, I have provided example implementations of the focal Tversky loss for both binary and multiclass classification. Custom losses can be defined by subclassing nn.Module. If alpha and beta are both set to 0.5 and gamma is set to 1, the result is the Dice loss.
class FocalTverskyLossBinary(nn.Module):
    def __init__(self, alpha=0.5, beta=0.5, gamma=1.0, epsilon=1e-7):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.epsilon = epsilon

    def forward(self, logits, targets):
        """
        Args:
            logits: A tensor of shape (N, 1, H, W) where N is the batch size,
                1 is the number of classes (binary classification), and H, W are height and width of the image.
            targets: A tensor of shape (N, H, W) where each value is 0 or 1.
        """
        probs = torch.sigmoid(logits)
        targets = targets.unsqueeze(1).float()  # Change shape to (N, 1, H, W)

        true_pos = torch.sum(probs * targets, dim=(0, 2, 3))
        false_neg = torch.sum(targets * (1 - probs), dim=(0, 2, 3))
        false_pos = torch.sum((1 - targets) * probs, dim=(0, 2, 3))

        tversky_index = (true_pos + self.epsilon) / (true_pos + self.alpha * false_neg + self.beta * false_pos + self.epsilon)
        focal_tversky_loss = torch.pow((1 - tversky_index), self.gamma)

        return focal_tversky_loss.mean()
class FocalTverskyLossMultiClass(nn.Module):
    def __init__(self, alpha=0.5, beta=0.5, gamma=1.0, epsilon=1e-7):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.epsilon = epsilon

    def forward(self, logits, targets):
        #logits: (N, C, H, W); targets: (N, H, W) integer (long) class codes
        num_classes = logits.shape[1]
        targets_one_hot = F.one_hot(targets, num_classes=num_classes).permute(0, 3, 1, 2).float()

        probs = torch.softmax(logits, dim=1)

        true_pos = torch.sum(probs * targets_one_hot, dim=(0, 2, 3))
        false_neg = torch.sum(targets_one_hot * (1 - probs), dim=(0, 2, 3))
        false_pos = torch.sum((1 - targets_one_hot) * probs, dim=(0, 2, 3))

        tversky_index = (true_pos + self.epsilon) / (true_pos + self.alpha * false_neg + self.beta * false_pos + self.epsilon)
        focal_tversky_loss = torch.pow((1 - tversky_index), self.gamma)

        return focal_tversky_loss.mean()
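The classes above sum over dimensions (0, 2, 3), so they expect image-shaped tensors of shape (N, C, H, W), whereas the band-mean data used in this section produce logits of shape (N, C). As a hedged sketch of the combination loss idea, the hypothetical DiceCELoss class below (not part of PyTorch) adds a soft multiclass Dice loss to cross entropy for (N, C) logits. It uses the Tversky index with alpha and beta both fixed at 0.5, which equals the Dice coefficient, and sums the soft counts over the batch dimension only. The diceWeight parameter is an illustrative assumption for balancing the two terms.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DiceCELoss(nn.Module):
    """Hypothetical combination loss: cross entropy + soft multiclass Dice loss.
    Expects logits of shape (N, C) and integer class targets of shape (N,)."""
    def __init__(self, diceWeight=1.0, epsilon=1e-7):
        super().__init__()
        self.diceWeight = diceWeight
        self.epsilon = epsilon

    def forward(self, logits, targets):
        numClasses = logits.shape[1]
        probs = torch.softmax(logits, dim=1)
        targetsOneHot = F.one_hot(targets, num_classes=numClasses).float()
        # Soft per-class counts, summed over the batch dimension
        truePos = torch.sum(probs * targetsOneHot, dim=0)
        falseNeg = torch.sum(targetsOneHot * (1 - probs), dim=0)
        falsePos = torch.sum((1 - targetsOneHot) * probs, dim=0)
        # Tversky index with alpha = beta = 0.5, i.e., the per-class Dice coefficient
        dice = (truePos + self.epsilon) / (truePos + 0.5 * falseNeg + 0.5 * falsePos + self.epsilon)
        diceLoss = 1 - dice.mean()
        ceLoss = F.cross_entropy(logits, targets)
        return ceLoss + self.diceWeight * diceLoss

# Usage with toy logits for a batch of 4 samples and 3 classes
toyTargets = torch.tensor([0, 1, 2, 1])
lossFn = DiceCELoss()
confidentLogits = 20.0 * F.one_hot(toyTargets, num_classes=3).float()
print(lossFn(confidentLogits, toyTargets).item())  # near 0 for confident, correct predictions

randomLogits = torch.randn(4, 3, requires_grad=True)
randomLoss = lossFn(randomLogits, toyTargets)
randomLoss.backward()  # gradients flow through both loss terms
```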
Class Balancing Dataset
Instead of modifying the loss function to deal with class imbalance, you may choose to modify how the training data are sampled. One means to do this is to incorporate a sampler into the DataLoader, which controls which samples are provided in the training loop. Specifically, classes that are less abundant will be sampled more often than those that are more abundant.
To accomplish this, I first create a new column in the training DataFrame called “weight”, which is initially filled with the class numeric codes. I then use the pandas replace() method to replace the class codes with their weights, as calculated using the compute_class_weight() function from scikit-learn. The weights are then converted to a tensor and moved to the device. The sampler is then defined using the WeightedRandomSampler class from PyTorch. Lastly, I redefine the DataLoader for the training set and specify the sampler. Note that shuffling cannot be used along with a sampler, so shuffle must be set to False.
This method can be especially useful for scene labeling tasks. However, it is more difficult to apply when the goal is pixel-level, or semantic, segmentation, because the unit of interest is then the pixel as opposed to the entire image chip. In these cases, I have found it necessary to use a more manual augmentation of the training data. For example, you may choose to exclude all, or retain only a subset, of the image chips that contain only background pixels. We will discuss this in more detail in the modules associated with semantic segmentation.
from torch.utils.data import WeightedRandomSampler

#Extract class codes (column 2, "code") and calculate balanced class weights
npLbls = np.array(trainAgg.iloc[:,2])
weights = compute_class_weight(class_weight='balanced', classes=np.unique(npLbls), y=npLbls)

#Fill a weight column with class codes, then replace each code with its class weight.
#The column is modified in place because trainDS references this same DataFrame.
trainAgg['weight'] = trainAgg['code']
trainAgg['weight'] = trainAgg['weight'].replace({code: w for code, w in zip(np.unique(npLbls), weights)})
weightsT = torch.tensor(trainAgg['weight'].to_numpy(), dtype=torch.float).to(device)
#Remove the helper column so EuroSat.__getitem__() still reads only the ten band means
trainAgg.drop(['weight'], axis=1, inplace=True)

sampler = WeightedRandomSampler(
    weightsT, num_samples=len(trainAgg), replacement=True
)

trainDL = torch.utils.data.DataLoader(trainDS, batch_size=256, shuffle=False, sampler=sampler,
                                      num_workers=0, pin_memory=False, drop_last=False)
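To confirm that weighted sampling rebalances the classes, the sketch below builds a hypothetical toy label vector (90 samples of class 0 and 10 of class 1), gives each sample the inverse of its class count as its weight, and draws from a WeightedRandomSampler. Because both classes then carry the same total weight, roughly half of the drawn samples come from the minority class. Since sampling is done with replacement, minority samples are repeated many times, which is the trade-off of this approach.

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Hypothetical imbalanced labels: 90 samples of class 0 and 10 samples of class 1
toyLabels = torch.tensor([0] * 90 + [1] * 10)
classCounts = torch.bincount(toyLabels)
# Each sample's weight is the inverse of its class count
sampleWeights = (1.0 / classCounts.float())[toyLabels]

# Seeded generator so the draw is reproducible
gen = torch.Generator().manual_seed(0)
demoSampler = WeightedRandomSampler(sampleWeights, num_samples=10000, replacement=True, generator=gen)
drawnIdx = torch.tensor(list(demoSampler))
fracMinority = (toyLabels[drawnIdx] == 1).float().mean().item()
print(f"Share of minority-class draws: {fracMinority:.3f}")
```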
Learning Rate Finder
Tools have been developed to help you select an appropriate learning rate or range of learning rates. Details of how this is conducted are described in the following paper by Leslie Smith:
Smith, L.N., 2017, March. Cyclical learning rates for training neural networks. In 2017 IEEE winter conference on applications of computer vision (WACV) (pp. 464-472). IEEE.
Unfortunately, this method is not natively available in PyTorch. Instead, I will make use of the implementation provided here: https://github.com/davidtvs/pytorch-lr-finder. You will need to install this package to execute the example code.
In the provided example, I am testing learning rates between 0.0001 and 1, increased linearly over up to 100 iterations. The result is a graph that depicts the loss at different learning rates. The suggested learning rate is the one at which the loss curve has its steepest negative slope.
To learn more about how this process works, please consult the paper referenced above and/or the associated lecture module. Generally, the learning rate starts very low and is gradually increased. Initially, the loss will not decrease because the learning rate is too low to allow for meaningful parameter updates. As the learning rate increases, the loss should begin to decrease. Eventually, however, the learning rate becomes too high, resulting in unstable learning and exploding gradients. The range of learning rates that allow for meaningful parameter updates, as indicated by a reduction in the loss, without unstable learning or exploding gradients is suggested to be optimal.
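The sketch below is a simplified, hypothetical version of this procedure written from scratch; it is not the torch-lr-finder implementation, which, among other things, smooths the loss curve and iterates over a DataLoader. The helper lrRangeTest (a name introduced here for illustration) increases the learning rate linearly from startLr to endLr using plain SGD on a single in-memory batch, records the loss at each step, stops early once the loss exceeds divergeTh times the best loss seen, and suggests the learning rate at the steepest negative slope. A synthetic two-class dataset stands in for the EuroSAT data, and the model is deep-copied so the caller's weights are left untouched.

```python
import copy
import numpy as np
import torch
import torch.nn as nn

def lrRangeTest(model, lossFn, xData, yData, startLr=1e-4, endLr=1.0,
                numIter=100, divergeTh=5.0):
    """Hypothetical, simplified linear learning rate range test."""
    model = copy.deepcopy(model)  # work on a copy so the original weights are untouched
    optimizer = torch.optim.SGD(model.parameters(), lr=startLr)
    lrsTried, losses = [], []
    bestLoss = float("inf")
    for lr in np.linspace(startLr, endLr, numIter):
        # Set the learning rate for this iteration
        for group in optimizer.param_groups:
            group["lr"] = float(lr)
        optimizer.zero_grad()
        loss = lossFn(model(xData), yData)
        loss.backward()
        optimizer.step()
        lrsTried.append(float(lr))
        losses.append(loss.item())
        bestLoss = min(bestLoss, loss.item())
        # Stop early once the loss diverges
        if not np.isfinite(loss.item()) or loss.item() > divergeTh * bestLoss:
            break
    # Steepest negative slope of the recorded loss curve (NaNs ignored)
    slopes = np.gradient(np.array(losses))
    suggestedLr = lrsTried[int(np.nanargmin(slopes))]
    return lrsTried, losses, suggestedLr

# Usage on synthetic data (hypothetical two-class problem)
torch.manual_seed(0)
xData = torch.randn(256, 10)
yData = (xData[:, 0] > 0).long()
toyModel = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))
lrsTried, losses, suggestedLr = lrRangeTest(toyModel, nn.CrossEntropyLoss(), xData, yData)
print(f"Iterations run: {len(losses)}, suggested LR: {suggestedLr:.4f}")
```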
#!pip install torch-lr-finder
from torch_lr_finder import LRFinder
model = myFCNLeaky(10, [256, 256], 10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters())
lr_finder = LRFinder(model, optimizer, criterion, device=device)
lr_finder.range_test(trainDL, val_loader=valDL, start_lr=0.0001, end_lr=1, num_iter=100, step_mode="linear")
Stopping early, the loss has diverged
Learning rate search finished. See the graph with {finder_name}.plot()
91%|#########1| 91/100 [01:41<00:10, 1.11s/it]
lr_finder.plot(log_lr=False)
LR suggestion: steepest gradient
Suggested LR: 2.93E-01
(<AxesSubplot: >, 0.293)
lr_finder.reset()
Learning Rate Scheduling
Another technique that has proved to be useful is adjusting the learning rate throughout the training process. There are many ways to do this. One of the simplest is to use a larger learning rate early in training and then reduce it after a defined number of epochs. The idea is that the model can be coarsely trained initially and then fine-tuned with a lower learning rate during later epochs.
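As a minimal sketch of this step-decay idea, PyTorch's StepLR scheduler multiplies the learning rate by gamma every step_size epochs. Below, a throwaway parameter stands in for a model (an assumption made only so the optimizer has something to manage), and the learning rate recorded at each epoch drops from 0.1 to 0.01 after 10 epochs and to 0.001 after 20. Unlike the one cycle example in this section, scheduler.step() is called once per epoch rather than once per mini-batch.

```python
import torch
import torch.nn as nn

# Placeholder parameter standing in for a model (hypothetical)
demoParam = nn.Parameter(torch.zeros(1))
demoOpt = torch.optim.SGD([demoParam], lr=0.1)
# Multiply the learning rate by gamma=0.1 every step_size=10 epochs
demoSched = torch.optim.lr_scheduler.StepLR(demoOpt, step_size=10, gamma=0.1)

epochLrs = []
for ep in range(30):
    epochLrs.append(demoOpt.param_groups[0]["lr"])
    # ... one epoch of training (forward pass, loss.backward()) would go here ...
    demoOpt.step()
    demoSched.step()  # called once per epoch

print(epochLrs[0], epochLrs[10], epochLrs[20])
```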
One method that has been shown to be especially effective is a one cycle learning rate policy, in which the learning rate starts low, rises to a maximum learning rate, and then descends back to a low learning rate as a single cycle over the training process. This method was proposed by Leslie Smith in work that builds on the cyclical learning rate paper referenced above (the PyTorch documentation for this scheduler cites Smith and Topin's “Super-Convergence” paper), and it has been implemented in PyTorch as a scheduler. A scheduler modifies the optimizer's settings, such as the learning rate, each time its step() method is called. Note that a variety of schedulers are available for different use cases. If you are interested in learning more about schedulers, please consult the PyTorch documentation.
To implement a one cycle learning rate policy in PyTorch, you can use the OneCycleLR class from torch.optim.lr_scheduler. This allows you to specify the maximum learning rate, the number of epochs, and the number of steps per epoch (i.e., the number of mini-batches). This information allows the scheduler to determine how to alter the learning rate over the entire training process. By default, OneCycleLR also cycles the momentum inversely to the learning rate.
To actually update the learning rate using this scheduler, you must call scheduler.step() in the training loop. Since this needs to happen after processing each mini-batch, it should be done within the for loop over the training mini-batches. In the example below, I call it after optimizer.step(). So, this is actually pretty easy to implement. It only requires (1) instantiating the scheduler and (2) calling its step() method within each training mini-batch iteration.
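To visualize what OneCycleLR does without training a real model, the sketch below steps a scheduler for a hypothetical 5 epochs of 20 mini-batches each and records the learning rate used for every mini-batch; a placeholder parameter stands in for the model. With PyTorch's defaults (pct_start=0.3, div_factor=25, final_div_factor=1e4), the learning rate starts at max_lr/25, rises to max_lr over the first 30% of the steps, and then anneals to a much smaller value by the final step. As in the training loop that follows, scheduler.step() comes after optimizer.step() for every mini-batch.

```python
import torch
import torch.nn as nn

# Placeholder parameter standing in for a model (hypothetical)
ocParam = nn.Parameter(torch.zeros(1))
ocOpt = torch.optim.SGD([ocParam], lr=0.1, momentum=0.9)
numEpochs, stepsPerEpoch = 5, 20  # hypothetical; in practice steps_per_epoch=len(trainDL)
ocSched = torch.optim.lr_scheduler.OneCycleLR(ocOpt, max_lr=0.1, epochs=numEpochs,
                                              steps_per_epoch=stepsPerEpoch)

batchLrs = []
for ep in range(numEpochs):
    for batchIdx in range(stepsPerEpoch):
        batchLrs.append(ocOpt.param_groups[0]["lr"])
        # ... forward pass and loss.backward() would go here ...
        ocOpt.step()    # update parameters for this mini-batch
        ocSched.step()  # then advance the one cycle schedule by one mini-batch

peakIdx = batchLrs.index(max(batchLrs))
print(f"start: {batchLrs[0]:.4f}, peak: {max(batchLrs):.4f} at step {peakIdx}, end: {batchLrs[-1]:.2e}")
```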
There are many other methods for augmenting learning rates. I recommend this post on Kaggle that demonstrates a variety of learning rate schedulers using PyTorch: https://www.kaggle.com/code/isbhargav/guide-to-pytorch-learning-rate-scheduling.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.1, epochs=50, steps_per_epoch=len(trainDL))
criterion = nn.CrossEntropyLoss().to(device)

#Define assessment metrics
acc = tm.Accuracy(task="multiclass", num_classes=10).to(device)
f1 = tm.F1Score(task="multiclass", num_classes=10).to(device)
kappa = tm.CohenKappa(task="multiclass", num_classes=10).to(device)

epochs = 50
saveFolder = "C:/myFiles/work/dl/eurosat_fcnn_models/"

eNum = []
t_loss = []
t_acc = []
t_f1 = []
t_kappa = []
v_loss = []
v_acc = []
v_f1 = []
v_kappa = []

#Loop over epochs
for epoch in range(1, epochs+1):
    #Put model in training mode
    model.train()
    #Initialize running training loss
    running_loss = 0.0
    #Loop over batches
    for batch_idx, (inputs, targets) in enumerate(trainDL):
        #Get data and move to device
        inputs, targets = inputs.to(device), targets.to(device)

        #Clear gradients
        optimizer.zero_grad()
        #Predict data
        outputs = model(inputs)
        #Calculate loss
        loss = criterion(outputs, targets)

        #Update running loss with batch results
        running_loss += loss.item()

        #Update metrics with batch results
        accT = acc(outputs, targets)
        f1T = f1(outputs, targets)
        kappaT = kappa(outputs, targets)

        #Backpropagate
        loss.backward()

        #Update parameters, then step the one cycle scheduler (once per mini-batch)
        optimizer.step()
        scheduler.step()

    #Accumulate loss and metrics at end of training epoch
    epoch_loss = running_loss/len(trainDL)
    accT = acc.compute()
    f1T = f1.compute()
    kappaT = kappa.compute()

    #Print losses and metrics at end of each epoch
    print(f'Epoch: {epoch}, Training Loss: {epoch_loss:.4f}, Training Accuracy: {accT:.4f}, Training F1: {f1T:.4f}, Training Kappa: {kappaT:.4f}')

    #Append results
    eNum.append(epoch)
    t_loss.append(epoch_loss)
    t_acc.append(accT.detach().cpu().numpy())
    t_f1.append(f1T.detach().cpu().numpy())
    t_kappa.append(kappaT.detach().cpu().numpy())

    #Reset metrics
    acc.reset()
    f1.reset()
    kappa.reset()

    #Set model in evaluation mode
    model.eval()
    #Initialize running validation loss
    running_loss_v = 0.0
    #Loop over validation batches
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(valDL):
            #Get data and move to device
            inputs, targets = inputs.to(device), targets.to(device)

            #Predict data
            outputs = model(inputs)
            #Calculate validation loss
            loss_v = criterion(outputs, targets)

            #Update running loss with batch results
            running_loss_v += loss_v.item()

            #Update metrics with batch results
            accV = acc(outputs, targets)
            f1V = f1(outputs, targets)
            kappaV = kappa(outputs, targets)

    #Accumulate loss and metrics at end of validation epoch
    epoch_loss_v = running_loss_v/len(valDL)
    accV = acc.compute()
    f1V = f1.compute()
    kappaV = kappa.compute()

    #Print validation loss and metrics
    print(f'Validation Loss: {epoch_loss_v:.4f}, Validation Accuracy: {accV:.4f}, Validation F1: {f1V:.4f}, Validation Kappa: {kappaV:.4f}')

    #Append results
    v_loss.append(epoch_loss_v)
    v_acc.append(accV.detach().cpu().numpy())
    v_f1.append(f1V.detach().cpu().numpy())
    v_kappa.append(kappaV.detach().cpu().numpy())

    #Reset metrics
    acc.reset()
    f1.reset()
    kappa.reset()

    #Save model weights for this epoch
    torch.save(model.state_dict(), saveFolder + 'eurosat_model_' + str(epoch) + '.pt')
    print(f'Model saved for epoch {epoch}.')
Gradient Accumulation
Very small mini-batch sizes can result in noisy or suboptimal parameter updates. However, it is not always possible to use a large mini-batch size. For example, if you are using large image chips, a complex architecture, and/or limited GPU VRAM, you may not be able to use a large mini-batch size. To get around this issue, you can implement gradient accumulation. The idea is that the weights/parameters are not updated after each training mini-batch is processed. Instead, the gradients are allowed to accumulate over a given number of mini-batches before an optimization step is performed. This approximates training with a larger mini-batch without having to hold the entire larger mini-batch in memory at once.
In the example below, I am allowing the gradients to accumulate for 4 mini-batches before applying the optimization or weight/parameter updates. This involves (1) defining a variable, in this case accum_iter, that specifies the number of mini-batches over which to accumulate the gradients, (2) dividing the loss by this variable for normalization, (3) defining a condition with an if statement so that the parameters are only updated after every 4 mini-batches or after the final training mini-batch, and (4) only applying the optimizer and clearing the gradients if one of these conditions is met. In my example below, this condition is checked inside of the for loop that iterates over the mini-batches, after the loss has been calculated and backpropagation has been performed. Also, I am not clearing the gradients at the beginning of each training mini-batch. Instead, this occurs within the condition, since the gradients cannot accumulate if they are cleared after every mini-batch. So, the gradients are only cleared after an optimization step has been performed.
In this specific example, gradient accumulation is not necessary since the mini-batch size is already large; it is included only for demonstration. Also, remember that the effective mini-batch size is the mini-batch size multiplied by accum_iter, and very large mini-batch sizes can be problematic. So, use gradient accumulation with care.
#Number of mini-batches over which to accumulate gradients
accum_iter = 4
eNum = []
t_loss = []
t_acc = []
t_f1 = []
t_kappa = []
v_loss = []
v_acc = []
v_f1 = []
v_kappa = []
#Start with cleared gradients (they are only cleared again after an optimization step)
optimizer.zero_grad()
#Loop over epochs
for epoch in range(1, epochs+1):
    #Put model in training mode
    model.train()
    #initialize running training loss
    running_loss = 0.0
    #Loop over batches
    for batch_idx, (inputs, targets) in enumerate(trainDL):
        #Get data and move to device
        inputs, targets = inputs.to(device), targets.to(device)
        #Predict data (gradients are NOT cleared here so that they can accumulate)
        outputs = model(inputs)
        #Calculate loss
        loss = criterion(outputs, targets)
        #Update running loss with batch results (using the unnormalized loss)
        running_loss += loss.item()
        #Calculate metrics
        accT = acc(outputs, targets)
        f1T = f1(outputs, targets)
        kappaT = kappa(outputs, targets)
        #Normalize loss by the number of accumulation steps
        loss = loss / accum_iter
        #Backpropagate (new gradients are added to the stored gradients)
        loss.backward()
        # Implement gradient accumulation: update parameters and clear gradients
        # every accum_iter mini-batches and after the final mini-batch
        if ((batch_idx + 1) % accum_iter == 0) or (batch_idx + 1 == len(trainDL)):
            optimizer.step()
            optimizer.zero_grad()
    # Accumulate loss and metrics at end of training epoch
    epoch_loss = running_loss/len(trainDL)
    accT = acc.compute()
    f1T = f1.compute()
    kappaT = kappa.compute()
    # Print Losses and metrics at end of each Epoch
    print(f'Epoch: {epoch}, Training Loss: {epoch_loss:.4f}, Training Accuracy: {accT:.4f}, Training F1: {f1T:.4f}, Training Kappa: {kappaT:.4f}')
    #Append results
    eNum.append(epoch)
    t_loss.append(epoch_loss)
    t_acc.append(accT.detach().cpu().numpy())
    t_f1.append(f1T.detach().cpu().numpy())
    t_kappa.append(kappaT.detach().cpu().numpy())
    #Reset metrics
    acc.reset()
    f1.reset()
    kappa.reset()
    #Set model in evaluation mode
    model.eval()
    #Initialize running validation loss
    running_loss_v = 0.0
    # loop over validation batches
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(valDL):
            #Get data and move to device
            inputs, targets = inputs.to(device), targets.to(device)
            #Predict data
            outputs = model(inputs)
            #Calculate validation loss
            loss_v = criterion(outputs, targets)
            #Update running validation loss with batch results
            running_loss_v += loss_v.item()
            #Update metrics with batch results
            accV = acc(outputs, targets)
            f1V = f1(outputs, targets)
            kappaV = kappa(outputs, targets)
    #Accumulate loss and metrics at end of validation epoch
    epoch_loss_v = running_loss_v/len(valDL)
    accV = acc.compute()
    f1V = f1.compute()
    kappaV = kappa.compute()
    #Print validation loss and metrics
    print(f'Validation Loss: {epoch_loss_v:.4f}, Validation Accuracy: {accV:.4f}, Validation F1: {f1V:.4f}, Validation Kappa: {kappaV:.4f}')
    #Append results
    v_loss.append(epoch_loss_v)
    v_acc.append(accV.detach().cpu().numpy())
    v_f1.append(f1V.detach().cpu().numpy())
    v_kappa.append(kappaV.detach().cpu().numpy())
    #Reset metrics
    acc.reset()
    f1.reset()
    kappa.reset()
    #Save model parameters for this epoch
    torch.save(model.state_dict(), saveFolder + 'eurosat_model_' + str(epoch) + '.pt')
    print(f'Model saved for epoch {epoch}.')
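Why divide the loss by accum_iter? The minimal sketch below offers a quick sanity check. It uses a hypothetical toy linear model and random data, not the EuroSAT model or data loaders used above. With the default mean-reduced loss and equal-sized mini-batches, accumulating the normalized losses over four mini-batches of 4 samples should reproduce the gradient from a single mini-batch of 16 samples, up to floating-point error.

```python
import torch
import torch.nn as nn

# Hypothetical toy model and data, used only to illustrate the idea
torch.manual_seed(0)
model = nn.Linear(5, 3)
criterion = nn.CrossEntropyLoss()  # default reduction="mean"

x = torch.randn(16, 5)
y = torch.randint(0, 3, (16,))
accum_iter = 4

# Reference: gradient from one mini-batch of 16 samples
model.zero_grad()
criterion(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Gradient accumulation: 4 mini-batches of 4 samples each
model.zero_grad()
for xb, yb in zip(x.split(4), y.split(4)):
    loss = criterion(model(xb), yb) / accum_iter  # normalize the loss
    loss.backward()  # gradients are summed into .grad across calls
accum_grad = model.weight.grad.clone()

print(torch.allclose(full_grad, accum_grad, atol=1e-6))  # True
```

This equivalence has two caveats. First, layers that compute statistics over the mini-batch, such as batch normalization, still only see the smaller mini-batches. So, accumulation approximates, but does not exactly reproduce, training with a larger mini-batch in those architectures. Second, if the final group of mini-batches in an epoch contains fewer than accum_iter mini-batches, dividing by accum_iter slightly down-weights that last update.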
Using Multiple GPUs
Training on multiple GPUs is an effective means to speed up the training process. So, if you are fortunate enough to have access to a computer with multiple CUDA-enabled GPUs or a GPU cluster, I highly recommend augmenting your code to make use of this hardware. The code below shows how to get the count of available GPUs.
devCnt = torch.cuda.device_count()
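There are a few methods that can be used to train your model over multiple GPUs. I tend to use nn.DataParallel(). It splits each training mini-batch into smaller mini-batches, processes them on separate GPUs in parallel, and then gathers the outputs back onto a single GPU. Note that the PyTorch documentation recommends DistributedDataParallel over nn.DataParallel for multi-GPU training, even on a single machine, although it requires more setup.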
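Implementing this method is very easy. You simply need to wrap your model in nn.DataParallel, as demonstrated below, and then move it to the device. The training loop can then be implemented as normal. Once wrapped, the original model is accessed through the wrapper's module attribute, for example, when loading saved parameters.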
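#Wrap the model for multi-GPU training and move it to the device
model = nn.DataParallel(model)
model.to(device)
#Load saved parameters into the underlying (wrapped) model
model.module.load_state_dict(torch.load("PATH TO SAVED PARAMETERS"))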
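The snippet below is a self-contained sketch of the same pattern, under a few assumptions:

- A small hypothetical stand-in network replaces the model used earlier.
- The model is only wrapped when more than one GPU is detected, so the example also runs on CPU-only machines, where the model is simply left unwrapped.
- An in-memory buffer replaces a saved file.

The sketch also shows why the module attribute matters. A wrapped model's state_dict() keys gain a "module." prefix, so saving and loading through the underlying network keeps parameter files compatible with an unwrapped model.

```python
import io

import torch
import torch.nn as nn

# Hypothetical stand-in for the model defined earlier in the module
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 4))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
devCnt = torch.cuda.device_count()

# Wrap only when several GPUs are available; each input batch is then
# split along dimension 0 and the pieces are processed on separate GPUs
if devCnt > 1:
    model = nn.DataParallel(model)
model.to(device)

# Access the underlying network whether or not it was wrapped
base_model = model.module if isinstance(model, nn.DataParallel) else model

# Save and reload parameters through the unwrapped network
buffer = io.BytesIO()  # stands in for a .pt file on disk
torch.save(base_model.state_dict(), buffer)
buffer.seek(0)
base_model.load_state_dict(torch.load(buffer, map_location=device))

# The forward pass is unchanged; outputs are gathered on the first device
outputs = model(torch.randn(8, 10, device=device))
print(outputs.shape)  # torch.Size([8, 4])

# A wrapped model prefixes every parameter name with "module."
print(list(nn.DataParallel(nn.Linear(2, 2)).state_dict()))
# ['module.weight', 'module.bias']
```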
Enhancing Reproducibility
Lastly, you may want to be able to obtain the same results when running the training process multiple times. Obtaining reproducibility is complex with PyTorch, especially when using GPU-based computation. The following paper provides some guidance on obtaining reproducibility within PyTorch:
Alahmari, S.S., Goldgof, D.B., Mouton, P.R. and Hall, L.O., 2020. Challenges for the repeatability of deep learning models. IEEE Access, 8, pp.211860-211868.
The code below demonstrates creating a function that can set multiple random seeds and other settings that should allow for reproducible results with PyTorch.
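import random
import numpy as np
import torch
def set_seed2(seed=2019):
    #Seed the Python, NumPy, and PyTorch (CPU and CUDA) random number generators
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    #Force cuDNN to use deterministic algorithms and disable benchmark-based algorithm selection
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False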
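Seeding the global generators may not control every source of randomness. PyTorch's reproducibility notes also describe two additional steps:

- Seeding DataLoader worker processes and passing a dedicated torch.Generator, so that the shuffling order repeats between runs.
- Calling torch.use_deterministic_algorithms(True) to request deterministic implementations of operations. This raises an error when an operation has no deterministic implementation.

The sketch below follows the DataLoader pattern from those notes on a hypothetical ten-sample toy dataset. It checks that two runs with the same seed produce the same mini-batch order.

```python
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def seed_worker(worker_id):
    # Derive per-worker NumPy and Python seeds from the seed PyTorch assigns
    # to each worker (only called when num_workers > 0)
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def shuffled_order(seed):
    # Hypothetical toy dataset: each sample is just its own index
    ds = TensorDataset(torch.arange(10))
    g = torch.Generator()
    g.manual_seed(seed)  # dedicated generator controls the shuffling
    dl = DataLoader(ds, batch_size=5, shuffle=True, num_workers=0,
                    worker_init_fn=seed_worker, generator=g)
    return [batch[0].tolist() for batch in dl]


run1 = shuffled_order(2019)
run2 = shuffled_order(2019)
print(run1 == run2)  # True
```

Even with these settings, results can still differ across PyTorch versions, hardware, and CUDA/cuDNN versions. So, reproducibility is best expected on a fixed software and hardware setup.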
Concluding Remarks
The goal of this module was to introduce some common methods that can be used to potentially improve the performance of the final model and/or to combat some common issues, such as class imbalance. If you are struggling with your training process, it might be worth investigating these techniques. It is also a good idea to explore the internet, blogs, and help sites for other suggestions that might be appropriate for your specific use case. Also, deep learning techniques are still developing, so new methods may be proposed that are not presented here. With that said, I have found the methods that I discussed here to be especially useful.