1. 学ぶ
  2. /
  3. コース
  4. /
  5. Pythonで学ぶDeep Reinforcement Learning

Connected

演習

最小構成のDQN損失関数

select_action() 関数の準備ができたので、エージェントの学習まであと一歩です。ここでは calculate_loss() を実装します。

calculate_loss() は、エピソード中の任意のステップに対するネットワークの損失を返します。

参考までに、損失は次の式で与えられます。

この演習には次のサンプルデータが読み込まれています。

state = torch.rand(8)
next_state = torch.rand(8)
action = select_action(q_network, state)
reward = 1
gamma = .99
done = False

指示

100 XP
  • 現在状態のQ値を取得します。
  • 次状態のQ値を取得します。
  • 目標Q値(TDターゲット)を計算します。
  • 損失関数(ベルマン誤差の二乗)を計算します。