
Fighting overfitting

1. Fighting overfitting

Previously, we learned how to detect overfitting by looking at the training and validation losses. In this video, we'll discover a few ways we can fight overfitting.

2. Reasons for overfitting

Recall that overfitting happens when the model does not generalize to unseen data. If we do not train the model correctly, it will start to memorize the training data, which leads to good performance on the training set but poor performance on the validation set. Several factors can lead to overfitting: a small dataset, a model with too much capacity, or excessively large weight values.

3. Fighting overfitting

To counter overfitting, we can reduce the model size or add a new type of layer called dropout. We can also use weight decay to force the parameters to remain small. We can get more data or use data augmentation. Let's explore these strategies.

4. "Regularization" using a dropout layer

A common way to fight overfitting is to add dropout layers to our neural network. Dropout is a "regularization" technique that randomly deactivates a fraction of neurons during training, preventing the model from becoming too dependent on specific features. Dropout layers are typically added after activation functions. The p argument sets the probability of a neuron being set to zero; in this example, 50% of the neurons are dropped. Dropout behaves differently during training and evaluation: during training, it randomly deactivates neurons, while during evaluation, it is disabled, ensuring all neurons are active for stable predictions. To switch between these modes, we use model.train() and model.eval().
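As a minimal sketch of this idea (the layer sizes and input shape here are only illustrative, not taken from the video), a dropout layer can be placed after the activation and toggled with model.train() and model.eval():

```python
import torch
import torch.nn as nn

# Dropout added after the activation function (illustrative layer sizes)
model = nn.Sequential(
    nn.Linear(8, 4),
    nn.ReLU(),
    nn.Dropout(p=0.5),  # each neuron has a 50% chance of being zeroed during training
    nn.Linear(4, 1),
)

features = torch.randn(16, 8)

model.train()                 # dropout active: neurons are randomly deactivated
train_output = model(features)

model.eval()                  # dropout disabled: all neurons contribute to the prediction
eval_output = model(features)
```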

5. Regularization with weight decay

The next strategy we will discover for reducing overfitting is weight decay, another form of regularization. In PyTorch, weight decay is added to the optimizer using the weight_decay parameter, typically set to a small value, for example, 0.001. This parameter adds a penalty to the loss function, encouraging smaller weights and helping the model generalize better. During each optimizer step, the penalty's gradient is added to each parameter's gradient, nudging the weights toward zero and preventing excessive weight growth. The higher we set the weight decay, the stronger the regularization, making overfitting less likely.
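A minimal sketch of this setup, reusing the model from the dropout example (the choice of SGD and the learning rate are only assumptions for illustration):

```python
import torch.optim as optim

# weight_decay adds the regularization penalty; 0.001 is the small value mentioned above
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=0.001)
```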

6. Data augmentation

Collecting more data can be expensive, but researchers have found a way to expand datasets artificially using data augmentation. Data augmentation is commonly applied to image data, which can be rotated and scaled, so that different views of the same face become available as "new" data points. While we won't discuss how to augment data here, it remains a valuable method for combating overfitting when additional data isn't available.
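Although augmentation itself is not covered here, a hedged sketch of what such a pipeline might look like, assuming torchvision is available (the specific transforms and values are only illustrative):

```python
from torchvision import transforms

# Illustrative augmentation pipeline: rotated and rescaled versions of each
# image act as "new" data points during training.
augment = transforms.Compose([
    transforms.RandomRotation(degrees=15),                      # rotate by up to +/- 15 degrees
    transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0)),   # random scaling and cropping
    transforms.ToTensor(),
])
```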

7. Let's practice!

Let's practice!