1. Nauka
  2. /
  3. Kursy
  4. /
  5. Wprowadzenie do uczenia głębokiego z PyTorch

Connected

ćwiczenie

Pisanie pętli treningowej

W scikit-learn pętla treningowa jest ukryta w metodzie .fit(), natomiast w PyTorch trzeba ją zdefiniować ręcznie. Daje to większą elastyczność, ale wymaga własnej implementacji.

W tym ćwiczeniu napiszesz pętlę do trenowania modelu przewidującego wynagrodzenia.

Do wizualizacji przykładowych predykcji udostępniono funkcję show_results().

Dostępne są następujące importy: pandas jako pd, torch, torch.nn jako nn, torch.optim jako optim, a także DataLoader i TensorDataset z torch.utils.data.

Utworzone zostały następujące zmienne: num_epochs – liczba epok (ustawiona na 5); dataloader – obiekt DataLoader; model – sieć neuronowa; criterion – funkcja straty nn.MSELoss(); optimizer – optymalizator SGD.

Instrukcje 1/3

undefined XP
    1
    2
    3
  • Napisz pętlę for iterującą po dataloader; zagnieźdź ją wewnątrz pętli for iterującej po zakresie równym liczbie epok.
  • Wyzeruj gradienty optymalizatora.