Barebone DQN action selection
The select_action() function lets the agent select the action with the highest Q-value at every step.
The function takes the Q-network and the current state as arguments and returns the index of the action with the highest Q-value.
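To make "the index of the action with the highest Q-value" concrete, here is a tiny sketch that applies greedy selection to hand-picked Q-values. The four numbers are illustrative only and do not come from the exercise environment.

```python
import torch

# Illustrative Q-values for a hypothetical state with 4 possible actions
q_values = torch.tensor([0.12, -0.40, 0.87, 0.33])

# Greedy selection: torch.argmax returns a 0-dim tensor holding the index
# of the largest value; .item() turns it into a plain Python int
action = torch.argmax(q_values).item()
print(action)  # 2
```

The largest value, 0.87, sits at index 2, so the agent would pick action 2.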
The Q-network is instantiated as q_network, and a random state has been loaded in your environment with state = torch.rand(8) to give you example data to work with.
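The exercise does not show how q_network is defined. The sketch below is a hypothetical stand-in, assuming a small fully connected network that maps the 8-dimensional state to one Q-value per action; the class name QNetwork, the hidden size of 64 and the choice of 4 actions are all assumptions for illustration.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the exercise's q_network (architecture assumed)
class QNetwork(nn.Module):
    def __init__(self, state_size: int = 8, action_size: int = 4):
        super().__init__()
        self.fc1 = nn.Linear(state_size, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, action_size)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)  # raw Q-values, one per action, no output activation

torch.manual_seed(0)
q_network = QNetwork()
state = torch.rand(8)  # same shape as the example state in the exercise

# A single forward pass on an unbatched state yields one Q-value per action
q_values = q_network(state)
print(q_values.shape)  # torch.Size([4])
```

The key point is the interface: calling the network on a state returns a 1-D tensor with one entry per action, which is what select_action() works with.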
This exercise is part of the course Deep Reinforcement Learning in Python.
Exercise instructions
- Calculate the Q-values corresponding to each action in the state passed as an argument.
- Obtain the index corresponding to the action with the highest Q-value.
Have a go at this exercise by completing this sample code.
import torch

def select_action(q_network, state):
    # Calculate the Q-values
    q_values = ____
    print("Q-values:", [round(x, 2) for x in q_values.tolist()])
    # Obtain the action index with highest Q-value
    action = torch.____.item()
    print(f"Action selected: {action}, with q-value {q_values[action]:.2f}")
    return action

select_action(q_network, state)
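One possible completion of the scaffold, as a sketch: the Q-values come from a single forward pass, q_network(state), and the greedy action from torch.argmax. Wrapping the forward pass in torch.no_grad() is an optional addition not in the original, since no gradients are needed when only choosing an action. The nn.Sequential network below is a hypothetical stand-in for the environment's q_network, assumed to have 8 inputs and 4 actions.

```python
import torch
import torch.nn as nn

def select_action(q_network, state):
    # Calculate the Q-values (no gradient tracking needed for acting)
    with torch.no_grad():
        q_values = q_network(state)
    print("Q-values:", [round(x, 2) for x in q_values.tolist()])
    # Obtain the action index with highest Q-value
    action = torch.argmax(q_values).item()
    print(f"Action selected: {action}, with q-value {q_values[action]:.2f}")
    return action

# Hypothetical stand-in network and state, for illustration only
torch.manual_seed(0)
q_network = nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 4))
state = torch.rand(8)

action = select_action(q_network, state)
```

This is purely greedy selection: the agent always exploits its current Q-estimates and never explores.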