
Enhancing training with Lightning callbacks

1. Enhancing training with Lightning callbacks

Welcome back! In this video, we'll automate tasks like saving the best model and stopping training at the right moment to streamline our workflow.

2. What are callbacks?

We all want to keep our best model without wasting time or overfitting, and Lightning callbacks handle this for us automatically. So, what are callbacks? In PyTorch Lightning, they're objects with hook methods that are triggered at key stages of training. They let us add custom actions, like saving the best model or stopping early, while keeping our code clean and modular. Callbacks add flexibility and control without cluttering the training loop, and they can also track metrics and stop training if performance stalls.

3. What are callbacks?

Here, we define a class MyPrintingCallback that inherits from Lightning's Callback class. By implementing the on_train_start method, we instruct the callback to print a message at the start of training. Similarly, we define an on_train_end method that prints a message when training finishes. This shows how callbacks let us inject actions, such as logging, adjusting learning rates, or stopping early, at different stages of training.
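As a minimal sketch, the callback described above could look like the following. The import path assumes the lightning.pytorch namespace (older installs use pytorch_lightning instead), and the printed messages are just placeholders.

```python
from lightning.pytorch.callbacks import Callback

class MyPrintingCallback(Callback):
    def on_train_start(self, trainer, pl_module):
        # Runs once, right before the first training epoch begins
        print("Training is starting")

    def on_train_end(self, trainer, pl_module):
        # Runs once, after the final training epoch finishes
        print("Training is ending")
```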

4. Lightning ModelCheckpoint callback

One key callback is ModelCheckpoint, which saves model checkpoints automatically when the conditions we specify are met. After importing ModelCheckpoint, we create an instance and configure it. We set monitor to 'val_loss' so it tracks validation loss. The dirpath argument specifies where checkpoint files are saved. We define a filename template with placeholders for the epoch and validation loss; Lightning fills these in at runtime, making it easy to spot the best model. Setting save_top_k to 1 keeps only the best checkpoint, saving storage space. Finally, mode is set to 'min' because we want to minimize validation loss. With this setup, the best model is saved as training progresses.
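A sketch of that configuration is shown below. The directory and filename template are illustrative placeholders rather than fixed course values, and the import again assumes the lightning.pytorch namespace.

```python
from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",                           # metric to track
    dirpath="checkpoints/",                       # where checkpoint files are written
    filename="model-{epoch:02d}-{val_loss:.2f}",  # epoch and val_loss filled in at runtime
    save_top_k=1,                                 # keep only the single best checkpoint
    mode="min",                                   # lower val_loss is better
)
```

For monitor='val_loss' to work, the LightningModule has to log that metric, for example with self.log("val_loss", loss) in its validation_step.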

5. Lightning EarlyStopping callback

Another powerful callback is EarlyStopping, which halts training when improvement stops, preventing overfitting and saving time. We import EarlyStopping and create an instance, again monitoring 'val_loss'. Patience is set to 3, so training stops if validation loss doesn't improve for three consecutive epochs. This gives the model time to improve without wasting resources, and we can adjust the value based on experimental results. Mode is 'min' since we're minimizing loss.
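A minimal sketch of that setup, under the same import assumption as before:

```python
from lightning.pytorch.callbacks import EarlyStopping

early_stopping = EarlyStopping(
    monitor="val_loss",  # metric to watch
    patience=3,          # stop after 3 epochs without improvement
    mode="min",          # improvement means a lower val_loss
)
```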

6. Customizing and using Lightning callbacks

Suppose our primary goal is to maximize validation accuracy instead of minimizing loss. We can adjust the monitor parameter to 'val_accuracy'. We might set save_top_k to 2 in ModelCheckpoint to save the top two models, giving us options when selecting the best one for deployment. Since we're maximizing accuracy, we change mode to 'max'. In EarlyStopping, we might increase patience to 5, allowing more epochs for potential improvements in accuracy. These small tweaks help tailor training to different goals while managing resources efficiently. To use these callbacks, we need to pass them to our Lightning Trainer, so that they'll trigger at the right time during training.
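Putting these tweaks together, a sketch of the accuracy-oriented setup and how the callbacks are handed to the Trainer might look like this. The metric name 'val_accuracy' and the max_epochs value are assumptions for illustration; the model and dataloaders are defined elsewhere.

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

checkpoint_callback = ModelCheckpoint(
    monitor="val_accuracy",  # track accuracy instead of loss
    save_top_k=2,            # keep the two best checkpoints
    mode="max",              # higher accuracy is better
)

early_stopping = EarlyStopping(
    monitor="val_accuracy",
    patience=5,              # allow more epochs for accuracy to improve
    mode="max",
)

# Passing the callbacks to the Trainer lets Lightning trigger them automatically
trainer = Trainer(max_epochs=20, callbacks=[checkpoint_callback, early_stopping])
# trainer.fit(model, train_dataloader, val_dataloader)
```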

7. Let's practice!

In the upcoming exercises, we'll get hands-on experience implementing these callbacks in a training loop.