Splitting data with LightningDataModule
You will complete the setup method of a LightningDataModule. Proper dataset partitioning ensures the model is trained on one subset and validated on a separate one, so overfitting can be detected.
The dataset has already been imported for you.
This exercise is part of the course Scalable AI Models with PyTorch Lightning.
Exercise instructions
- Import random_split to split the dataset into training and validation.
- Split the dataset into training (80%) and validation (20%) using random_split (see the standalone sketch below).
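Before filling in the exercise, it may help to see random_split on its own. The sketch below is illustrative and not part of the exercise: the toy dataset and its size of 100 samples are assumptions chosen just to show how the 80/20 lengths are computed.

# Standalone sketch (illustrative only): random_split on a toy dataset
import torch
from torch.utils.data import TensorDataset, random_split

toy_dataset = TensorDataset(torch.randn(100, 4), torch.randint(0, 2, (100,)))

# Derive the 80/20 lengths from the dataset size so they always sum correctly
n_train = int(0.8 * len(toy_dataset))   # 80 samples
n_val = len(toy_dataset) - n_train      # 20 samples

train_subset, val_subset = random_split(toy_dataset, [n_train, n_val])
print(len(train_subset), len(val_subset))  # 80 20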
Hands-on interactive exercise
Finish this exercise by completing the sample code below.
# Import libraries
import lightning.pytorch as pl
from torch.utils.data import ____

class SplitDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        self.train_data = None
        self.val_data = None

    def setup(self, stage=None):
        # Split the dataset into training (80%) and validation (20%)
        self.____, self.____ = random_split(dataset, [____, ____])
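For reference, one possible completed version is sketched below. It assumes dataset is the pre-loaded dataset object provided by the exercise environment, and it derives integer lengths from len(dataset) so the two parts always sum to the full dataset size. Recent PyTorch versions also accept fractional lengths such as [0.8, 0.2] in place of integer counts.

# Possible completed solution (a sketch; `dataset` is assumed to be the
# pre-loaded dataset object from the exercise environment)
import lightning.pytorch as pl
from torch.utils.data import random_split

class SplitDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        self.train_data = None
        self.val_data = None

    def setup(self, stage=None):
        # Split the dataset into training (80%) and validation (20%)
        train_size = int(0.8 * len(dataset))
        val_size = len(dataset) - train_size
        self.train_data, self.val_data = random_split(dataset, [train_size, val_size])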