ComeçarComece gratuitamente

Buffer de reprodução de experiência priorizada

Você apresentará a classe PrioritizedExperienceReplay, uma estrutura de dados que será usada posteriormente para implementar o DQN com Prioritized Experience Replay.

PrioritizedExperienceReplay é um refinamento da classe ExperienceReplay que você tem usado até agora para treinar seus agentes DQN. Um buffer de reprodução de experiência priorizado garante que as transições amostradas nele sejam mais valiosas para o agente aprender do que com a amostragem uniforme.

Por enquanto, implemente os métodos .__init__(), .push(), .update_priorities(), .increase_beta() e .__len__(). O método final, .sample(), será o foco do próximo exercício.

Este exercício faz parte do curso

Aprendizado por reforço profundo em Python

Ver Curso

Instruções de exercício

  • Em .push(), inicialize a prioridade da transição com a prioridade máxima no buffer (ou 1 se o buffer estiver vazio).
  • Em .update_priorities(), defina a prioridade como o valor absoluto do erro correspondente em TD; adicione self.epsilon para cobrir casos extremos.
  • Em .increase_beta(), incremente beta em self.beta_increment; certifique-se de que beta nunca exceda 1.

Exercício interativo prático

Experimente este exercício preenchendo este código de exemplo.

class PrioritizedReplayBuffer:
    def __init__(
        self, capacity, alpha=0.6, beta=0.4, beta_increment=0.001, epsilon=0.01
    ):
        self.memory = deque(maxlen=capacity)
        self.alpha, self.beta, self.beta_increment, self.epsilon = (alpha, beta, beta_increment, epsilon)
        self.priorities = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        experience_tuple = (state, action, reward, next_state, done)
        # Initialize the transition's priority
        max_priority = ____
        self.memory.append(experience_tuple)
        self.priorities.append(max_priority)
    
    def update_priorities(self, indices, td_errors):
        for idx, td_error in zip(indices, td_errors.tolist()):
            # Update the transition's priority
            self.priorities[idx] = ____

    def increase_beta(self):
        # Increase beta if less than 1
        self.beta = ____

    def __len__(self):
        return len(self.memory)
      
buffer = PrioritizedReplayBuffer(capacity=3)
buffer.push(state=[1,3], action=2, reward=1, next_state=[2,4], done=False)
print("Transition in memory buffer:", buffer.memory)
print("Priority buffer:", buffer.priorities)
Editar e executar código