MulaiMulai sekarang secara gratis

Fixed Q-targets

Anda sedang mempersiapkan pelatihan Lunar Lander dengan fixed Q-targets. Sebagai prasyarat, Anda perlu membuat instance online network (yang memilih aksi) dan target network (digunakan untuk perhitungan TD-target).

Anda juga perlu mengimplementasikan fungsi update_target_network yang dapat digunakan pada setiap langkah pelatihan. Target network tidak diperbarui dengan gradient descent; sebagai gantinya, update_target_network mendorong bobotnya mendekati Q-network dalam jumlah kecil, sehingga tetap cukup stabil dari waktu ke waktu.

Perhatikan bahwa, khusus untuk latihan ini, Anda menggunakan jaringan yang sangat kecil agar kita dapat dengan mudah mencetak dan memeriksa state dictionary-nya. Jaringan ini hanya memiliki satu hidden layer berukuran dua; action space dan state space-nya juga berdimensi 2.

Fungsi print_state_dict() tersedia di lingkungan Anda untuk mencetak state dict.

Latihan ini adalah bagian dari kursus

Deep Reinforcement Learning dengan Python

Lihat Kursus

Petunjuk latihan

  • Peroleh .state_dict() untuk target network dan online network.
  • Perbarui state dict untuk target network dengan mengambil rata-rata tertimbang antara parameter online network dan target network, menggunakan tau sebagai bobot untuk online network.
  • Muat kembali state dict yang telah diperbarui ke target network.

Latihan interaktif praktis

Cobalah latihan ini dengan menyelesaikan kode contoh berikut.

def update_target_network(target_network, online_network, tau):
    # Obtain the state dicts for both networks
    target_net_state_dict = ____
    online_net_state_dict = ____
    for key in online_net_state_dict:
        # Calculate the updated state dict for the target network
        target_net_state_dict[key] = (online_net_state_dict[____] * ____ + target_net_state_dict[____] * ____)
        # Load the updated state dict into the target network
        target_network.____
    return None
  
print("online network weights:", print_state_dict(online_network))
print("target network weights (pre-update):", print_state_dict(target_network))
update_target_network(target_network, online_network, .001)
print("target network weights (post-update):", print_state_dict(target_network))
Edit dan Jalankan Kode