Batch Normalization
As a final improvement to the model architecture, let's add a batch normalization layer after each of the first two linear layers. Batch normalization tends to accelerate training convergence and protects the model from vanishing and exploding gradients.
Both torch.nn and torch.nn.init have already been imported for you as nn and init, respectively. Once you implement the change in the model architecture, be ready to answer a short question on how batch normalization works!
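To prepare for that question, here is a minimal sketch of what nn.BatchNorm1d computes at training time. The mini-batch of 4 samples with 16 features is made up purely for illustration (16 matches the output size of the first hidden layer below): each feature is standardized with the batch mean and biased variance, then rescaled and shifted by the layer's learnable parameters.

import torch
import torch.nn as nn

# Hypothetical mini-batch: 4 samples, 16 features (matches fc1's output size)
x = torch.randn(4, 16)

bn = nn.BatchNorm1d(16)  # one learnable scale (gamma) and shift (beta) per feature
y = bn(x)                # training mode: statistics come from the current batch

# The same computation by hand (ignoring the running statistics the layer
# also tracks for use at evaluation time):
mu = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)
x_hat = (x - mu) / torch.sqrt(var + bn.eps)   # standardize each feature
y_manual = bn.weight * x_hat + bn.bias        # learnable affine transform

print(torch.allclose(y, y_manual, atol=1e-6))  # True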
Have a go at this exercise by completing this sample code.
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(9, 16)
        # Add two batch normalization layers
        ____ = ____
        self.fc2 = nn.Linear(16, 8)
        ____ = ____
        self.fc3 = nn.Linear(8, 1)

        init.kaiming_uniform_(self.fc1.weight)
        init.kaiming_uniform_(self.fc2.weight)
        init.kaiming_uniform_(self.fc3.weight, nonlinearity="sigmoid")
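For reference, one possible completion is sketched below. The attribute names self.bn1 and self.bn2, the ReLU activations, and the forward method are assumptions made for illustration; the scaffold above only fixes the placement of nn.BatchNorm1d after the first two linear layers.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(9, 16)
        # Batch norm over the 16 outputs of fc1
        self.bn1 = nn.BatchNorm1d(16)
        self.fc2 = nn.Linear(16, 8)
        # Batch norm over the 8 outputs of fc2
        self.bn2 = nn.BatchNorm1d(8)
        self.fc3 = nn.Linear(8, 1)

        init.kaiming_uniform_(self.fc1.weight)
        init.kaiming_uniform_(self.fc2.weight)
        init.kaiming_uniform_(self.fc3.weight, nonlinearity="sigmoid")

    def forward(self, x):
        # Linear -> batch norm -> activation for each hidden block
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn2(self.fc2(x)))
        # Sigmoid output for binary classification
        return torch.sigmoid(self.fc3(x))

Sizing each nn.BatchNorm1d to its preceding linear layer's output (16 and 8 here) is what makes the normalization apply feature-wise to that layer's activations.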