Splitting data with LightningDataModule
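In this exercise, you will complete the setup method of a LightningDataModule. Partitioning the dataset properly ensures that the model is trained on one subset and validated on another, so that overfitting becomes visible as a gap between training and validation performance.
The dataset has been pre-imported for you and is available as dataset.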
This exercise is part of the course
Scalable AI Models with PyTorch Lightning
Exercise instructions
- Import random_split to split the dataset into training and validation.
- Split the dataset into training (80%) and validation (20%) using random_split.
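As a rough illustration of what random_split does on its own, the sketch below splits a toy dataset 80/20. The toy_dataset, its shape, and the seed value are assumptions made up for this example; they are not the course's pre-imported dataset.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Hypothetical stand-in data: 100 samples with 4 features and a binary label.
toy_dataset = TensorDataset(torch.randn(100, 4), torch.randint(0, 2, (100,)))

# Explicit integer lengths must sum to len(toy_dataset).
n_train = int(0.8 * len(toy_dataset))
n_val = len(toy_dataset) - n_train

# A seeded generator is optional; it makes the random assignment reproducible.
generator = torch.Generator().manual_seed(42)
train_subset, val_subset = random_split(
    toy_dataset, [n_train, n_val], generator=generator
)

print(len(train_subset), len(val_subset))  # 80 20
```

Computing the validation length as the remainder keeps the two lengths consistent with the dataset size even when 80% is not a whole number. Recent PyTorch releases also accept fractions such as [0.8, 0.2], which may be what the exercise expects.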
Hands-on interactive exercise
Try this exercise by completing the sample code below.
# Import libraries
import lightning.pytorch as pl
from torch.utils.data import ____

class SplitDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        self.train_data = None
        self.val_data = None

    def setup(self, stage=None):
        # Split the dataset into training (80%) and validation (20%)
        self.____, self.____ = random_split(dataset, [____, ____])