IniziaInizia gratis

Suddividere i dati con LightningDataModule

Completerai il metodo setup in un LightningDataModule. Una suddivisione corretta dell'insieme di dati garantisce che il modello venga addestrato su un sottoinsieme e validato su un altro, evitando l'overfitting.

Il dataset è già stato pre-importato.

Questo esercizio fa parte del corso

Modelli di AI scalabili con PyTorch Lightning

Visualizza il corso

Istruzioni dell'esercizio

  • Importa random_split per suddividere l'insieme di dati in training e validation.
  • Suddividi l'insieme di dati in training (80%) e validation (20%) usando random_split.

Esercizio pratico interattivo

Prova a risolvere questo esercizio completando il codice di esempio.

# Import libraries 
import lightning.pytorch as pl
from torch.utils.data import ____

class SplitDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        self.train_data = None
        self.val_data = None
    def setup(self, stage=None):
        # Split the dataset into training (80%) and validation (20%)
        self.____, self.____ = random_split(dataset, [____, ____])
Modifica ed esegui il codice