Diviser les données avec LightningDataModule
Vous devrez suivre la méthode de l'setup
dans un LightningDataModule
. Un partitionnement adéquat des ensembles de données garantit que le modèle est entraîné sur un sous-ensemble et validé sur un autre, ce qui évite le surapprentissage.
Le module d'extension « dataset
» a déjà été pré-importé.
Cet exercice fait partie du cours
Modèles d'IA évolutifs avec PyTorch Lightning
Instructions
- Importez l'
random_split
pour diviser l'ensemble de données en deux parties : formation et validation. - Divisez l'ensemble de données en deux parties : formation (80 %) et validation (20 %) à l'aide de l'
random_split
.
Exercice interactif pratique
Essayez cet exercice en complétant cet exemple de code.
# Import libraries
import lightning.pytorch as pl
from torch.utils.data import ____
class SplitDataModule(pl.LightningDataModule):
def __init__(self):
super().__init__()
self.train_data = None
self.val_data = None
def setup(self, stage=None):
# Split the dataset into training (80%) and validation (20%)
self.____, self.____ = random_split(dataset, [____, ____])