CommencerCommencez gratuitement

Diviser les données avec LightningDataModule

Vous devrez suivre la méthode de l'setup dans un LightningDataModule. Un partitionnement adéquat des ensembles de données garantit que le modèle est entraîné sur un sous-ensemble et validé sur un autre, ce qui évite le surapprentissage.

Le module d'extension « dataset » a déjà été pré-importé.

Cet exercice fait partie du cours

<cours>Modèles d'IA évolutifs avec PyTorch Lightning</cours>
Voir le cours

Instructions de l’exercice

  • Importez l'random_split pour diviser l'ensemble de données en deux parties : formation et validation.
  • Divisez l'ensemble de données en deux parties : formation (80 %) et validation (20 %) à l'aide de l'random_split.

Exercice interactif pratique

Essayez cet exercice en complétant ce code d’exemple.

# Import libraries 
import lightning.pytorch as pl
from torch.utils.data import ____

class SplitDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        self.train_data = None
        self.val_data = None
    def setup(self, stage=None):
        # Split the dataset into training (80%) and validation (20%)
        self.____, self.____ = random_split(dataset, [____, ____])
Modifier et exécuter le code