1. Học hỏi
  2. /
  3. Khoa Học
  4. /
  5. Deep Learning cho Ảnh với PyTorch

Connected

Bài tập

Hàm mất mát của Discriminator

Đến lúc định nghĩa hàm mất mát cho discriminator. Hãy nhớ rằng nhiệm vụ của discriminator là phân loại ảnh là thật hay giả. Vì vậy, discriminator sẽ chịu mất mát nếu nó phân loại đầu ra của generator là thật (nhãn 1) hoặc phân loại ảnh thật là giả (nhãn 0).

Hãy định nghĩa hàm disc_loss() để tính mất mát của discriminator. Hàm nhận năm đối số:

  • gen: mô hình generator
  • disc: mô hình discriminator
  • real: một mẫu ảnh thật từ dữ liệu huấn luyện
  • num_images: số lượng ảnh trong một batch
  • z_dim: kích thước của nhiễu ngẫu nhiên đầu vào

Hướng dẫn

100 XP
  • Dùng discriminator để phân loại ảnh fake và gán dự đoán vào disc_pred_fake.
  • Tính thành phần mất mát cho ảnh giả bằng cách gọi criterion trên dự đoán của discriminator cho ảnh giả và một tensor toàn số 0 có cùng shape.
  • Dùng discriminator để phân loại ảnh real và gán dự đoán vào disc_pred_real.
  • Tính thành phần mất mát cho ảnh thật bằng cách gọi criterion trên dự đoán của discriminator cho ảnh thật và một tensor toàn số 1 có cùng shape.