Suddividere i dati con LightningDataModule
Completerai il metodo setup in un LightningDataModule. Una suddivisione corretta dell'insieme di dati garantisce che il modello venga addestrato su un sottoinsieme e validato su un altro, evitando l'overfitting.
Il dataset è già stato pre-importato.
Questo esercizio fa parte del corso
Modelli di AI scalabili con PyTorch Lightning
Istruzioni dell'esercizio
- Importa
random_splitper suddividere l'insieme di dati in training e validation. - Suddividi l'insieme di dati in training (80%) e validation (20%) usando
random_split.
Esercizio pratico interattivo
Prova a risolvere questo esercizio completando il codice di esempio.
# Import libraries
import lightning.pytorch as pl
from torch.utils.data import ____
class SplitDataModule(pl.LightningDataModule):
def __init__(self):
super().__init__()
self.train_data = None
self.val_data = None
def setup(self, stage=None):
# Split the dataset into training (80%) and validation (20%)
self.____, self.____ = random_split(dataset, [____, ____])