
Defining models with LightningModule

1. Defining models with LightningModule

Welcome back! Now that we've seen LightningModule and Trainer applied to a simple model, let's deepen our understanding so we can build more complex models.

2. LightningModule in focus

In the previous video, we saw how LightningModule not only encapsulates our model architecture but also organizes training logic into a single, manageable unit. In this video, we zoom in on how to use LightningModule to build a robust classification model. Think of it as the blueprint that brings order and clarity to your deep learning projects.

3. Defining the init method

At the heart of every LightningModule is the __init__ method. When we define a model class by inheriting from LightningModule, calling super().__init__() ensures proper initialization of the parent class. Specifically, it guarantees that essential internal components provided by PyTorch Lightning, such as automated handling of training loops, logging, and checkpointing, are available by default in our new class. As illustrated in the code, after initializing the parent class we define the input, activation, and output layers explicitly, allowing them to be reused during training. This structured approach keeps our model logic modular and easy to maintain, even as our projects grow in complexity.
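Here is a minimal sketch of what this might look like in code. The layer attribute names and the constructor arguments (input_size, hidden_size, num_classes) are illustrative assumptions, not the course's exact identifiers.

```python
import torch
import torch.nn as nn
import pytorch_lightning as pl

class ClassificationModel(pl.LightningModule):
    def __init__(self, input_size, hidden_size, num_classes):
        # Initialize the parent class so Lightning's built-in
        # training-loop handling, logging, and checkpointing are available.
        super().__init__()
        # Define the layers explicitly so forward() can reuse them.
        self.input_layer = nn.Linear(input_size, hidden_size)
        self.activation = nn.ReLU()
        self.output_layer = nn.Linear(hidden_size, num_classes)
```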

4. Implementing the forward method

Now, let's implement the forward method: the core function that defines how data flows through our network. In this method, the input is processed sequentially through each layer: first a linear transformation, then an activation, and finally another linear layer that produces the output. These steps ensure that the model is fully prepared for both inference and backpropagation.
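Continuing the sketch above, the forward method might read as follows; the layer names match the assumed __init__ from the previous section.

```python
    def forward(self, x):
        # Data flows sequentially through each layer defined in __init__.
        x = self.input_layer(x)      # linear transformation
        x = self.activation(x)       # non-linearity (ReLU)
        return self.output_layer(x)  # raw class scores (logits)
```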

5. Example: classifying handwritten digits

Now let's ground the forward method in a real task. Suppose we want to classify handwritten digits from the classic MNIST dataset. We first load the images as tensors and wrap them in PyTorch DataLoaders. Next, we instantiate the ClassificationModel, passing the flattened image size, a 128-unit hidden layer, and ten output classes. Finally, Lightning's `Trainer` orchestrates the training loop with a single line, a call to `fit`, so our model learns to distinguish digits with minimal boilerplate.
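A sketch of this setup is below. The data directory, batch size, and epoch count are assumptions, and it presumes ClassificationModel also defines a training_step and configure_optimizers (a minimal version appears in the next section).

```python
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Load MNIST images as tensors and flatten each 28x28 image to 784 values.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda t: t.view(-1)),
])
train_set = datasets.MNIST("data/", train=True, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)

# Flattened image size in, a 128-unit hidden layer, ten output classes.
model = ClassificationModel(input_size=28 * 28, hidden_size=128, num_classes=10)

# A single call to fit() runs the whole training loop.
trainer = pl.Trainer(max_epochs=5)
trainer.fit(model, train_loader)
```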

6. Integrating the model with classification tasks

Now let's see how all the pieces come together for a classification task. This entire flow is housed in our LightningModule, which keeps everything neatly organized. Here, data goes from the input layer, through our ReLU activation, and into the final output. We typically feed these raw outputs into a softmax for class probabilities or into a cross-entropy loss during training. Finally, we pair our module with the Lightning Trainer to run the training loop, so we can focus on designing and refining our model without being hindered by repetitive boilerplate code.
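To make the pairing with the Trainer concrete, here is one minimal way the loss and optimizer hooks could be added to the class sketched earlier; the Adam optimizer and its learning rate are illustrative assumptions.

```python
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)  # invokes forward()
        # cross_entropy applies log-softmax internally, so raw logits go in.
        loss = nn.functional.cross_entropy(logits, y)
        self.log("train_loss", loss)  # Lightning's built-in logging
        return loss

    def configure_optimizers(self):
        # Illustrative choice: Adam with a common default learning rate.
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```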

7. Let's practice!

Now let's practice what we've learned here in the following exercises.