Buffer experience replay
Sekarang Anda akan membuat struktur data untuk mendukung Experience Replay, yang memungkinkan agen Anda belajar jauh lebih efisien.
Replay buffer ini harus mendukung dua operasi:
- Menyimpan pengalaman dalam memorinya untuk pengambilan sampel di masa depan.
- "Memutar ulang" sekumpulan pengalaman masa lalu yang diambil secara acak dari memorinya.
Karena data yang diambil dari replay buffer akan digunakan sebagai masukan ke jaringan neural, buffer harus mengembalikan Tensor torch untuk kemudahan.
Modul torch dan random serta kelas deque telah diimpor ke lingkungan latihan Anda.
Latihan ini adalah bagian dari kursus
Deep Reinforcement Learning dengan Python
Petunjuk latihan
- Lengkapi metode
push()dariReplayBufferdengan menambahkanexperience_tupleke memori buffer. - Dalam metode
sample(), ambil sampel acak berukuranbatch_sizedariself.memory. - Masih di
sample(), sampel awalnya berupa daftar tuple; pastikan diubah menjadi sebuah tuple berisi daftar. - Ubah
actions_tensormenjadi bentuk(batch_size, 1)alih-alih(batch_size).
Latihan interaktif praktis
Cobalah latihan ini dengan menyelesaikan kode contoh berikut.
class ReplayBuffer:
def __init__(self, capacity):
self.memory = deque([], maxlen=capacity)
def push(self, state, action, reward, next_state, done):
experience_tuple = (state, action, reward, next_state, done)
# Append experience_tuple to the memory buffer
self.memory.____
def __len__(self):
return len(self.memory)
def sample(self, batch_size):
# Draw a random sample of size batch_size
batch = ____(____, ____)
# Transform batch into a tuple of lists
states, actions, rewards, next_states, dones = ____
states_tensor = torch.tensor(states, dtype=torch.float32)
rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
next_states_tensor = torch.tensor(next_states, dtype=torch.float32)
dones_tensor = torch.tensor(dones, dtype=torch.float32)
# Ensure actions_tensor has shape (batch_size, 1)
actions_tensor = torch.tensor(actions, dtype=torch.long).____
return states_tensor, actions_tensor, rewards_tensor, next_states_tensor, dones_tensor