1. Apprendre
  2. /
  3. Cours
  4. /
  5. Середній рівень Deep Learning з PyTorch

Connected

Exercice

Цикл тренування RNN

Час тренувати модель прогнозування споживання електроенергії!

Ви використаєте мережу LSTM, яку визначили раніше. Її екземпляр уже створено та збережено в net, так само як і dataloader_train, який ви побудували до цього. Також вам знадобиться torch.nn, який уже імпортовано як nn.

У цій вправі ви потренуєте модель лише три епохи, щоб переконатися, що навчання відбувається як очікується. Почнімо!

Instructions

100 XP
  • Налаштуйте функцію втрат середньоквадратичної помилки (Mean Squared Error) і збережіть її в criterion.
  • Змініть форму seqs до (batch size, sequence length, num features), у нашому випадку це (32, 96, 1), і знову присвойте результат seqs.
  • Передайте seqs у модель, щоб отримати її outputs.
  • На основі раніше обчислених величин розрахуйте втрати та збережіть їх у loss.