Membuat DataLoader pelatihan
Sekarang setelah himpunan data kita dibagi, kita perlu mendefinisikan data loader untuk menyediakan batch data selama pelatihan. DataLoader memuat data ke memori secara efisien dan memungkinkan pengacakan untuk generalisasi yang lebih baik. Dalam latihan ini, Anda akan melengkapi metode train_dataloader.
Latihan ini adalah bagian dari kursus
Model AI yang Dapat Diskalakan dengan PyTorch Lightning
Petunjuk latihan
- Impor
DataLoader. - Kembalikan
DataLoaderyang memuatself.train_data, dengan mengaktifkan pengacakan untuk meningkatkan generalisasi.
Latihan interaktif praktis
Cobalah latihan ini dengan menyelesaikan kode contoh berikut.
# Import libraries
from torch.utils.data import ____
import lightning.pytorch as pl
class LoaderDataModule(pl.LightningDataModule):
def __init__(self):
super().__init__()
self.train_data = None
self.val_data = None
def setup(self, stage=None):
self.train_data, self.val_data = random_split(dataset, [80, 20])
def train_dataloader(self):
# Complete DataLoader
return ____(____, batch_size=16, shuffle=____)