ComenzarEmpieza gratis

Muestreo del búfer PER

Antes de que puedas utilizar la clase Tampón de Experiencias Priorizadas para entrenar a tu agente, todavía tienes que implementar el método .sample(). Este método toma como argumento el tamaño de la muestra que quieres dibujar, y devuelve las transiciones muestreadas como tensors, junto con sus índices en el búfer de memoria y su peso de importancia.

Se ha precargado un buffer con capacidad 10 en tu entorno para que tomes muestras de él.

Este ejercicio forma parte del curso

Aprendizaje profundo por refuerzo en Python

Ver curso

Instrucciones de ejercicio

  • Calcula la probabilidad de muestreo asociada a cada transición.
  • Dibuja los índices correspondientes a las transiciones de la muestra; np.random.choice(a, s, p=p) toma una muestra de tamaño s con reemplazamiento de la matriz a, basándose en la matriz de probabilidad p.
  • Calcula el peso de importancia asociado a cada transición.

Ejercicio interactivo práctico

Pruebe este ejercicio completando este código de muestra.

def sample(self, batch_size):
    priorities = np.array(self.priorities)
    # Calculate the sampling probabilities
    probabilities = ____ / np.sum(____)
    # Draw the indices for the sample
    indices = np.random.choice(____)
    # Calculate the importance weights
    weights = (1 / (len(self.memory) * ____)) ** ____
    weights /= np.max(weights)
    states, actions, rewards, next_states, dones = zip(*[self.memory[idx] for idx in indices])
    weights = [weights[idx] for idx in indices]
    states_tensor = torch.tensor(states, dtype=torch.float32)
    rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
    next_states_tensor = torch.tensor(next_states, dtype=torch.float32)
    dones_tensor = torch.tensor(dones, dtype=torch.float32)
    weights_tensor = torch.tensor(weights, dtype=torch.float32)
    actions_tensor = torch.tensor(actions, dtype=torch.long).unsqueeze(1)
    return (states_tensor, actions_tensor, rewards_tensor, next_states_tensor,
            dones_tensor, indices, weights_tensor)

PrioritizedReplayBuffer.sample = sample
print("Sampled transitions:\n", buffer.sample(3))
Editar y ejecutar código