MulaiMulai sekarang secara gratis

Sampling dari buffer PER

Sebelum Anda dapat menggunakan kelas Prioritized Experience Buffer untuk melatih agen, Anda masih perlu mengimplementasikan metode .sample(). Metode ini menerima argumen berupa ukuran sampel yang ingin Anda ambil, dan mengembalikan transisi yang diambil sebagai tensors, beserta indeksnya dalam buffer memori dan bobot kepentingannya.

Sebuah buffer dengan kapasitas 10 telah dimuat sebelumnya di lingkungan Anda untuk diambil sampelnya.

Latihan ini adalah bagian dari kursus

Deep Reinforcement Learning dengan Python

Lihat Kursus

Petunjuk latihan

  • Hitung probabilitas pengambilan sampel yang terkait dengan setiap transisi.
  • Ambil indeks yang sesuai dengan transisi dalam sampel; np.random.choice(a, s, p=p) mengambil sampel berukuran s dengan pengembalian dari array a, berdasarkan array probabilitas p.
  • Hitung bobot kepentingan yang terkait dengan setiap transisi.

Latihan interaktif praktis

Cobalah latihan ini dengan menyelesaikan kode contoh berikut.

def sample(self, batch_size):
    priorities = np.array(self.priorities)
    # Calculate the sampling probabilities
    probabilities = ____ / np.sum(____)
    # Draw the indices for the sample
    indices = np.random.choice(____)
    # Calculate the importance weights
    weights = (1 / (len(self.memory) * ____)) ** ____
    weights /= np.max(weights)
    states, actions, rewards, next_states, dones = zip(*[self.memory[idx] for idx in indices])
    weights = [weights[idx] for idx in indices]
    states_tensor = torch.tensor(states, dtype=torch.float32)
    rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
    next_states_tensor = torch.tensor(next_states, dtype=torch.float32)
    dones_tensor = torch.tensor(dones, dtype=torch.float32)
    weights_tensor = torch.tensor(weights, dtype=torch.float32)
    actions_tensor = torch.tensor(actions, dtype=torch.long).unsqueeze(1)
    return (states_tensor, actions_tensor, rewards_tensor, next_states_tensor,
            dones_tensor, indices, weights_tensor)

PrioritizedReplayBuffer.sample = sample
print("Sampled transitions:\n", buffer.sample(3))
Edit dan Jalankan Kode