Mulai sekarangMulai gratis

Sampling dari buffer PER

Sebelum Anda dapat menggunakan kelas Prioritized Experience Buffer untuk melatih agen, Anda masih perlu mengimplementasikan metode .sample(). Metode ini menerima argumen berupa ukuran sampel yang ingin Anda ambil, dan mengembalikan transisi yang diambil sebagai tensors, beserta indeksnya dalam buffer memori dan bobot kepentingannya.

Sebuah buffer dengan kapasitas 10 telah dimuat sebelumnya di lingkungan Anda untuk diambil sampelnya.

Latihan ini merupakan bagian dari kursus

Deep Reinforcement Learning dengan Python

Lihat Kursus

Instruksi latihan

  • Hitung probabilitas pengambilan sampel yang terkait dengan setiap transisi.
  • Ambil indeks yang sesuai dengan transisi dalam sampel; np.random.choice(a, s, p=p) mengambil sampel berukuran s dengan pengembalian dari array a, berdasarkan array probabilitas p.
  • Hitung bobot kepentingan yang terkait dengan setiap transisi.

Latihan interaktif langsung praktik

Cobalah latihan ini dengan melengkapi kode contoh ini.

def sample(self, batch_size):
    priorities = np.array(self.priorities)
    # Calculate the sampling probabilities
    probabilities = ____ / np.sum(____)
    # Draw the indices for the sample
    indices = np.random.choice(____)
    # Calculate the importance weights
    weights = (1 / (len(self.memory) * ____)) ** ____
    weights /= np.max(weights)
    states, actions, rewards, next_states, dones = zip(*[self.memory[idx] for idx in indices])
    weights = [weights[idx] for idx in indices]
    states_tensor = torch.tensor(states, dtype=torch.float32)
    rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
    next_states_tensor = torch.tensor(next_states, dtype=torch.float32)
    dones_tensor = torch.tensor(dones, dtype=torch.float32)
    weights_tensor = torch.tensor(weights, dtype=torch.float32)
    actions_tensor = torch.tensor(actions, dtype=torch.long).unsqueeze(1)
    return (states_tensor, actions_tensor, rewards_tensor, next_states_tensor,
            dones_tensor, indices, weights_tensor)

PrioritizedReplayBuffer.sample = sample
print("Sampled transitions:\n", buffer.sample(3))
Edit dan Jalankan Kode