Prioritized experience replay buffer
In this exercise, you will introduce the PrioritizedReplayBuffer class, a data structure that you will later use to implement DQN with Prioritized Experience Replay. PrioritizedReplayBuffer is a refinement of the ExperienceReplay class that you have been using so far to train your DQN agents. Rather than sampling uniformly, a prioritized replay buffer samples transitions in proportion to their priority (derived from the TD error), so the transitions the agent learns from are, on average, more informative than those obtained with uniform sampling.
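Before writing any code, it helps to see what the buffer's hyperparameters do. The following is a minimal illustration (not part of the exercise; the priority values are made up) of how proportional prioritization converts priorities into sampling probabilities, and how beta drives the importance-sampling correction:

import numpy as np

# Illustrative values: |TD error| + epsilon for three hypothetical transitions
priorities = np.array([0.5, 2.0, 0.1])
alpha, beta = 0.6, 0.4

# Proportional prioritization: P(i) = p_i**alpha / sum_k(p_k**alpha);
# alpha=0 recovers uniform sampling, alpha=1 samples strictly by priority
scaled = priorities ** alpha
probs = scaled / scaled.sum()

# Importance-sampling weights correct the bias that prioritization introduces;
# beta is annealed toward 1 over training (hence increase_beta below)
weights = (len(priorities) * probs) ** (-beta)
weights /= weights.max()

print("Sampling probabilities:", probs)
print("Importance-sampling weights:", weights)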
For now, implement the methods .__init__(), .push(), .update_priorities(), .increase_beta(), and .__len__(). The final method, .sample(), will be the focus of the next exercise.
This exercise is part of the course Deep Reinforcement Learning in Python.
Instructions
- In .push(), initialize the transition's priority to the maximum priority in the buffer (or 1 if the buffer is empty).
- In .update_priorities(), set each priority to the absolute value of the corresponding TD error; add self.epsilon so that no transition ever has exactly zero priority (see the sketch after this list).
- In .increase_beta(), increment beta by self.beta_increment, ensuring beta never exceeds 1.
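If you want to sanity-check your approach before running the exercise, here is one plausible way to fill in the three blanks, matching the hints above (a sketch, not necessarily the official solution):

# .push(): new transitions get the current maximum priority, or 1 if empty
max_priority = max(self.priorities) if self.priorities else 1.0

# .update_priorities(): |TD error| plus epsilon keeps every priority positive
self.priorities[idx] = abs(td_error) + self.epsilon

# .increase_beta(): anneal beta upward, capped at 1
self.beta = min(1.0, self.beta + self.beta_increment)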
Hands-on interactive exercise
Try this exercise by completing the sample code below.
from collections import deque

class PrioritizedReplayBuffer:
    def __init__(
        self, capacity, alpha=0.6, beta=0.4, beta_increment=0.001, epsilon=0.01
    ):
        self.memory = deque(maxlen=capacity)
        self.alpha, self.beta = alpha, beta
        self.beta_increment, self.epsilon = beta_increment, epsilon
        self.priorities = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        experience_tuple = (state, action, reward, next_state, done)
        # Initialize the transition's priority
        max_priority = ____
        self.memory.append(experience_tuple)
        self.priorities.append(max_priority)

    def update_priorities(self, indices, td_errors):
        for idx, td_error in zip(indices, td_errors.tolist()):
            # Update the transition's priority
            self.priorities[idx] = ____

    def increase_beta(self):
        # Increase beta if less than 1
        self.beta = ____

    def __len__(self):
        return len(self.memory)

buffer = PrioritizedReplayBuffer(capacity=3)
buffer.push(state=[1,3], action=2, reward=1, next_state=[2,4], done=False)
print("Transition in memory buffer:", buffer.memory)
print("Priority buffer:", buffer.priorities)