ComenzarEmpieza gratis

Buffer de repetición de experiencia priorizada

Presentarás la clase PrioritizedExperienceReplay, una estructura de datos que utilizarás más adelante para implementar DQN con la Repetición de Experiencias Priorizadas.

PrioritizedExperienceReplay es un perfeccionamiento de la clase ExperienceReplay que has estado utilizando hasta ahora para entrenar a tus agentes DQN. Un búfer de repetición de experiencias priorizado garantiza que las transiciones muestreadas de él sean más valiosas para que el agente aprenda de ellas que con un muestreo uniforme.

Por ahora, implementa los métodos .__init__(), .push(), .update_priorities(), .increase_beta() y .__len__(). El último método, .sample(), será el objeto del siguiente ejercicio.

Este ejercicio forma parte del curso

Aprendizaje profundo por refuerzo en Python

Ver curso

Instrucciones de ejercicio

  • En .push(), inicializa la prioridad de la transición a la máxima prioridad del búfer (o a 1 si el búfer está vacío).
  • En .update_priorities(), ajusta la prioridad al valor absoluto del error correspondiente de TD; añade self.epsilon para cubrir los casos límite.
  • En .increase_beta(), incrementa beta en self.beta_increment; asegúrate de que beta nunca sea superior a 1.

Ejercicio interactivo práctico

Pruebe este ejercicio completando este código de muestra.

class PrioritizedReplayBuffer:
    def __init__(
        self, capacity, alpha=0.6, beta=0.4, beta_increment=0.001, epsilon=0.01
    ):
        self.memory = deque(maxlen=capacity)
        self.alpha, self.beta, self.beta_increment, self.epsilon = (alpha, beta, beta_increment, epsilon)
        self.priorities = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        experience_tuple = (state, action, reward, next_state, done)
        # Initialize the transition's priority
        max_priority = ____
        self.memory.append(experience_tuple)
        self.priorities.append(max_priority)
    
    def update_priorities(self, indices, td_errors):
        for idx, td_error in zip(indices, td_errors.tolist()):
            # Update the transition's priority
            self.priorities[idx] = ____

    def increase_beta(self):
        # Increase beta if less than 1
        self.beta = ____

    def __len__(self):
        return len(self.memory)
      
buffer = PrioritizedReplayBuffer(capacity=3)
buffer.push(state=[1,3], action=2, reward=1, next_state=[2,4], done=False)
print("Transition in memory buffer:", buffer.memory)
print("Priority buffer:", buffer.priorities)
Editar y ejecutar código