1. Học hỏi
  2. /
  3. Khoa Học
  4. /
  5. Deep Learning cho Ảnh với PyTorch

Connected

Bài tập

Vòng lặp huấn luyện

Cuối cùng thì mọi công sức bạn bỏ ra để định nghĩa kiến trúc mô hình và hàm loss cũng đã đến lúc phát huy: đến giờ huấn luyện rồi! Nhiệm vụ của bạn là triển khai và chạy vòng lặp huấn luyện GAN. Lưu ý: một lệnh break được đặt sau batch dữ liệu đầu tiên để tránh thời gian chạy quá lâu.

Hai bộ tối ưu disc_opt và gen_opt đã được khởi tạo bằng bộ tối ưu Adam(). Các hàm tính loss mà bạn đã định nghĩa trước đó, gen_loss() và disc_loss(), đã sẵn sàng để dùng. Một dataloader cũng đã được chuẩn bị cho bạn.

Nhắc lại rằng:

  • Tham số của disc_loss() là: gen, disc, real, cur_batch_size, z_dim.
  • Tham số của gen_loss() là: gen, disc, cur_batch_size, z_dim.

Hướng dẫn

100 XP
  • Tính loss của discriminator dùng disc_loss() bằng cách truyền lần lượt generator, discriminator, mẫu ảnh thật, kích thước batch hiện tại, và kích thước nhiễu là 16, rồi gán kết quả vào d_loss.
  • Tính gradient từ d_loss.
  • Tính loss của generator dùng gen_loss() bằng cách truyền lần lượt generator, discriminator, kích thước batch hiện tại, và kích thước nhiễu là 16, rồi gán kết quả vào g_loss.
  • Tính gradient từ g_loss.