Fixed Q-targets
Anda sedang mempersiapkan pelatihan Lunar Lander dengan fixed Q-targets. Sebagai prasyarat, Anda perlu membuat instance online network (yang memilih aksi) dan target network (digunakan untuk perhitungan TD-target).
Anda juga perlu mengimplementasikan fungsi update_target_network yang dapat digunakan pada setiap langkah pelatihan. Target network tidak diperbarui dengan gradient descent; sebagai gantinya, update_target_network mendorong bobotnya mendekati Q-network dalam jumlah kecil, sehingga tetap cukup stabil dari waktu ke waktu.
Perhatikan bahwa, khusus untuk latihan ini, Anda menggunakan jaringan yang sangat kecil agar kita dapat dengan mudah mencetak dan memeriksa state dictionary-nya. Jaringan ini hanya memiliki satu hidden layer berukuran dua; action space dan state space-nya juga berdimensi 2.
Fungsi print_state_dict() tersedia di lingkungan Anda untuk mencetak state dict.
Latihan ini adalah bagian dari kursus
Deep Reinforcement Learning dengan Python
Petunjuk latihan
- Peroleh
.state_dict()untuk target network dan online network. - Perbarui state dict untuk target network dengan mengambil rata-rata tertimbang antara parameter online network dan target network, menggunakan
tausebagai bobot untuk online network. - Muat kembali state dict yang telah diperbarui ke target network.
Latihan interaktif praktis
Cobalah latihan ini dengan menyelesaikan kode contoh berikut.
def update_target_network(target_network, online_network, tau):
# Obtain the state dicts for both networks
target_net_state_dict = ____
online_net_state_dict = ____
for key in online_net_state_dict:
# Calculate the updated state dict for the target network
target_net_state_dict[key] = (online_net_state_dict[____] * ____ + target_net_state_dict[____] * ____)
# Load the updated state dict into the target network
target_network.____
return None
print("online network weights:", print_state_dict(online_network))
print("target network weights (pre-update):", print_state_dict(target_network))
update_target_network(target_network, online_network, .001)
print("target network weights (post-update):", print_state_dict(target_network))