Splitting data with LightningDataModule
You will complete the setup method in a LightningDataModule. Proper dataset partitioning ensures that the model is trained on one subset and validated on another, preventing overfitting. The dataset has already been pre-imported.
This exercise is part of the course Scalable AI Models with PyTorch Lightning.
Exercise instructions
- Import random_split to split the dataset into training and validation.
- Split the dataset into training (80%) and validation (20%) using random_split.
Hands-on interactive exercise
Have a go at this exercise by completing this sample code.
# Import libraries
import lightning.pytorch as pl
from torch.utils.data import ____

class SplitDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        self.train_data = None
        self.val_data = None

    def setup(self, stage=None):
        # Split the dataset into training (80%) and validation (20%)
        self.____, self.____ = random_split(dataset, [____, ____])
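One way the blanks can be filled in is sketched below. To keep the sketch self-contained, it uses only torch.utils.data and a small stand-in TensorDataset in place of the pre-imported dataset (the stand-in data is an assumption, not part of the exercise); the commented line shows how the same call reads inside the DataModule's setup method.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in for the pre-imported dataset (assumption): 100 random samples
dataset = TensorDataset(torch.randn(100, 4), torch.randint(0, 2, (100,)))

# Integer lengths for an 80/20 split
train_size = int(0.8 * len(dataset))   # 80
val_size = len(dataset) - train_size   # 20

# Inside setup, the blanks resolve to:
# self.train_data, self.val_data = random_split(dataset, [train_size, val_size])
train_data, val_data = random_split(dataset, [train_size, val_size])

print(len(train_data), len(val_data))
```

Computing val_size by subtraction guarantees the two lengths sum to len(dataset) even when 0.8 * len(dataset) is not an integer; recent versions of PyTorch also accept fractions directly, as in random_split(dataset, [0.8, 0.2]).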