CommencerCommencer gratuitement

Diviser les données avec LightningDataModule

Vous devrez suivre la méthode de l'setup dans un LightningDataModule. Un partitionnement adéquat des ensembles de données garantit que le modèle est entraîné sur un sous-ensemble et validé sur un autre, ce qui évite le surapprentissage.

Le module d'extension « dataset » a déjà été pré-importé.

Cet exercice fait partie du cours

Modèles d'IA évolutifs avec PyTorch Lightning

Afficher le cours

Instructions

  • Importez l'random_split pour diviser l'ensemble de données en deux parties : formation et validation.
  • Divisez l'ensemble de données en deux parties : formation (80 %) et validation (20 %) à l'aide de l'random_split.

Exercice interactif pratique

Essayez cet exercice en complétant cet exemple de code.

# Import libraries 
import lightning.pytorch as pl
from torch.utils.data import ____

class SplitDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        self.train_data = None
        self.val_data = None
    def setup(self, stage=None):
        # Split the dataset into training (80%) and validation (20%)
        self.____, self.____ = random_split(dataset, [____, ____])
Modifier et exécuter le code