Buffer de reprodução de experiência priorizada
Você apresentará a classe PrioritizedExperienceReplay
, uma estrutura de dados que será usada posteriormente para implementar o DQN com Prioritized Experience Replay.
PrioritizedExperienceReplay
é um refinamento da classe ExperienceReplay
que você tem usado até agora para treinar seus agentes DQN. Um buffer de reprodução de experiência priorizado garante que as transições amostradas nele sejam mais valiosas para o agente aprender do que com a amostragem uniforme.
Por enquanto, implemente os métodos .__init__()
, .push()
, .update_priorities()
, .increase_beta()
e .__len__()
. O método final, .sample()
, será o foco do próximo exercício.
Este exercício faz parte do curso
Aprendizado por reforço profundo em Python
Instruções de exercício
- Em
.push()
, inicialize a prioridade da transição com a prioridade máxima no buffer (ou 1 se o buffer estiver vazio). - Em
.update_priorities()
, defina a prioridade como o valor absoluto do erro correspondente em TD; adicioneself.epsilon
para cobrir casos extremos. - Em
.increase_beta()
, incremente beta emself.beta_increment
; certifique-se de quebeta
nunca exceda 1.
Exercício interativo prático
Experimente este exercício preenchendo este código de exemplo.
class PrioritizedReplayBuffer:
def __init__(
self, capacity, alpha=0.6, beta=0.4, beta_increment=0.001, epsilon=0.01
):
self.memory = deque(maxlen=capacity)
self.alpha, self.beta, self.beta_increment, self.epsilon = (alpha, beta, beta_increment, epsilon)
self.priorities = deque(maxlen=capacity)
def push(self, state, action, reward, next_state, done):
experience_tuple = (state, action, reward, next_state, done)
# Initialize the transition's priority
max_priority = ____
self.memory.append(experience_tuple)
self.priorities.append(max_priority)
def update_priorities(self, indices, td_errors):
for idx, td_error in zip(indices, td_errors.tolist()):
# Update the transition's priority
self.priorities[idx] = ____
def increase_beta(self):
# Increase beta if less than 1
self.beta = ____
def __len__(self):
return len(self.memory)
buffer = PrioritizedReplayBuffer(capacity=3)
buffer.push(state=[1,3], action=2, reward=1, next_state=[2,4], done=False)
print("Transition in memory buffer:", buffer.memory)
print("Priority buffer:", buffer.priorities)