1. Nauka
  2. /
  3. Kursy
  4. /
  5. Głębokie uczenie z PyTorch – poziom średnio zaawansowany

Connected

ćwiczenie

Pętla treningowa RNN

Czas wytrenować model prognozowania zużycia energii elektrycznej!

Skorzystasz z sieci LSTM zdefiniowanej wcześniej, która została zainicjalizowana i przypisana do zmiennej net, podobnie jak zbudowany wcześniej dataloader_train. Będziesz też potrzebować modułu torch.nn, który został już zaimportowany jako nn.

W tym ćwiczeniu wytrenuj model przez tylko trzy epoki, żeby upewnić się, że trening przebiega zgodnie z oczekiwaniami. Do dzieła!

Instrukcje

100 XP
  • Skonfiguruj funkcję straty Mean Squared Error i przypisz ją do zmiennej criterion.
  • Przekształć seqs do kształtu (batch size, sequence length, num features), czyli w naszym przypadku (32, 96, 1), i przypisz wynik z powrotem do seqs.
  • Przekaż seqs do modelu, aby uzyskać jego outputs.
  • Na podstawie wcześniej obliczonych wartości oblicz stratę i przypisz ją do zmiennej loss.