1. Učit se
  2. /
  3. Kurzy
  4. /
  5. Intermediate Deep Learning with PyTorch

Connected

cvičení

Trénovací smyčka RNN

Je čas natrénovat model pro předpověď spotřeby elektrické energie!

Použiješ LSTM síť, kterou jsi definoval/a dříve a která byla vytvořena a přiřazena do net, stejně jako dataloader_train, který jsi sestavil/a předtím. Budeš také potřebovat torch.nn, který už je naimportovaný jako nn.

V tomto cvičení natrénuješ model po dobu tří epoch, abys ověřil/a, že trénink probíhá správně. Pojďme na to!

Pokyny

100 XP
  • Nastav funkci Mean Squared Error loss a přiřaď ji do criterion.
  • Změň tvar seqs na (batch size, sequence length, num features), což je v našem případě (32, 96, 1), a výsledek znovu přiřaď do seqs.
  • Předej seqs modelu a získej jeho outputs.
  • Na základě dříve vypočtených hodnot vypočítej ztrátu a přiřaď ji do loss.