Batch Normalization

As a final improvement to the model architecture, let's add a batch normalization layer after each of the first two linear layers. Batch normalization tends to accelerate training convergence and helps protect the model against vanishing and exploding gradients.

Both torch.nn and torch.nn.init have already been imported for you as nn and init, respectively. Once you implement the change in the model architecture, be ready to answer a short question on how batch normalization works!
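
As a quick refresher on how batch normalization works, the standard formulation can be written as follows. The notation is an assumption introduced here, not taken from the exercise: $x_i$ is one feature's value for the $i$-th sample in a mini-batch of size $m$, $\epsilon$ is a small constant for numerical stability, and $\gamma$ and $\beta$ are the learnable scale and shift.

```latex
\mu_B = \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad
\sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu_B)^2

\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad
y_i = \gamma\,\hat{x}_i + \beta
```

During training, the statistics are computed per feature over the current mini-batch. At evaluation time, PyTorch's batch norm layers use running estimates of the mean and variance instead.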

This exercise is part of the course

Intermediate Deep Learning with PyTorch

Hands-on interactive exercise

Have a go at this exercise by completing this sample code.

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(9, 16)
        # Add two batch normalization layers
        ____ = ____
        self.fc2 = nn.Linear(16, 8)
        ____ = ____
        self.fc3 = nn.Linear(8, 1)
        
        init.kaiming_uniform_(self.fc1.weight)
        init.kaiming_uniform_(self.fc2.weight)
        init.kaiming_uniform_(self.fc3.weight, nonlinearity="sigmoid")
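
One possible way to fill in the blanks is sketched below. The attribute names `bn1` and `bn2` and the whole `forward` method (ELU activations, sigmoid output) are assumptions made for illustration, since the template above does not show them. The only firm constraint is that each `nn.BatchNorm1d` takes the number of output features of the linear layer it follows.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(9, 16)
        # Batch norm over the 16 features produced by fc1
        self.bn1 = nn.BatchNorm1d(16)
        self.fc2 = nn.Linear(16, 8)
        # Batch norm over the 8 features produced by fc2
        self.bn2 = nn.BatchNorm1d(8)
        self.fc3 = nn.Linear(8, 1)

        init.kaiming_uniform_(self.fc1.weight)
        init.kaiming_uniform_(self.fc2.weight)
        init.kaiming_uniform_(self.fc3.weight, nonlinearity="sigmoid")

    def forward(self, x):
        # Assumed forward pass: linear -> batch norm -> activation
        x = F.elu(self.bn1(self.fc1(x)))
        x = F.elu(self.bn2(self.fc2(x)))
        # Sigmoid output, matching the initialization chosen for fc3
        return torch.sigmoid(self.fc3(x))


net = Net()
batch = torch.randn(4, 9)  # 4 samples, 9 input features
print(net(batch).shape)  # torch.Size([4, 1])
```

Note that in training mode `nn.BatchNorm1d` needs more than one sample per batch to compute batch statistics, which is why the usage example passes four rows.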