Perfecting the forward method
After setting up the layers in the __init__ method, you use the forward method to define how data flows through them. In PyTorch Lightning, this separation keeps your code clean and easy to maintain. You've already seen how to structure the constructor; now it's time to focus on the forward pass and keep your classification logic clear and ready for training. Here, the layers in __init__ are already defined for you, so you can concentrate purely on the forward flow.
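The sketch below illustrates the split in plain Python, with no dependencies. The class names ToyLinear and ToyClassifier and the fixed weights are hypothetical, chosen only for illustration. This is a simplified stand-in, not how PyTorch implements nn.Module. The constructor only creates the layers, forward only chains them, and calling the object dispatches to forward. nn.Module.__call__ does the same thing, and it also runs any registered hooks.

```python
# Dependency-free toy: hypothetical names, NOT PyTorch's actual implementation.
class ToyLinear:
    """Computes y = W x + b for one input vector given as a list of floats."""

    def __init__(self, weights, bias):
        self.weights = weights  # one row of weights per output unit
        self.bias = bias

    def __call__(self, x):
        return [sum(w * xi for w, xi in zip(row, x)) + b
                for row, b in zip(self.weights, self.bias)]


def relu(v):
    # Element-wise max(0, v): negative values become 0.0
    return [max(0.0, vi) for vi in v]


class ToyClassifier:
    def __init__(self):
        # __init__ only declares the layers (fixed, made-up weights)
        self.hidden = ToyLinear([[1.0, -1.0], [0.5, 0.5]], [0.0, 0.0])
        self.output = ToyLinear([[1.0, 1.0]], [0.0])

    def forward(self, x):
        # forward only describes how data flows through those layers
        x = self.hidden(x)
        x = relu(x)
        x = self.output(x)
        return x

    def __call__(self, x):
        # Mirrors how calling a PyTorch module ends up in forward()
        return self.forward(x)


model = ToyClassifier()
print(model([1.0, 2.0]))  # [1.5]
```

Here the hidden layer produces [-1.0, 1.5], ReLU clamps that to [0.0, 1.5], and the output layer sums the result to 1.5.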
The lightning.pytorch and torch.nn modules have already been imported as pl and nn.
This exercise is part of the course
Scalable AI Models with PyTorch Lightning
Exercise instructions
- Implement the forward method inside ClassifierModel.
- Apply a ReLU activation after the hidden layer.
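PyTorch offers several equivalent ways to apply ReLU. The short sketch below assumes torch is installed and uses a made-up tensor. It applies torch.relu, torch.nn.functional.relu and an nn.ReLU() module to the same values. In this exercise the layers in __init__ are fixed, so a functional call such as torch.relu is probably the most natural choice inside forward. An nn.ReLU() module would normally be created in __init__ alongside the other layers.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Made-up "hidden layer output" with a negative, a zero and a positive value
h = torch.tensor([[-2.0, 0.0, 3.0]])

a = torch.relu(h)    # plain function on the torch namespace
b = F.relu(h)        # functional API
c = nn.ReLU()(h)     # module form, usually instantiated in __init__

print(a)  # tensor([[0., 0., 3.]])
```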
Have a go at this exercise by completing this sample code.
# Already imported in the exercise environment
import lightning.pytorch as pl
import torch.nn as nn

class ClassifierModel(pl.LightningModule):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.hidden = nn.Linear(input_dim, hidden_dim)
        self.output = nn.Linear(hidden_dim, output_dim)

    # Define forward method
    def ____(self, ____):
        # Complete the forward pass
        x = self.hidden(x)
        x = ____(x)
        x = self.output(x)
        return x