3

我有一个关于这个SARSA FA 的问题。

在输入单元格 142 中,我看到了这个修改后的更新

w += alpha * (reward - discount * q_hat_next) * q_hat_grad

其中q_hat_nextQ(S', a')q_hat_gradQ(S, a)(假设S, a, R, S' a'序列)的导数。

我的问题是更新不应该是这样的吗?

w += alpha * (reward + discount * q_hat_next - q_hat) * q_hat_grad

修改后的更新背后的直觉是什么?

4

1 回答 1

0

我认为你是对的。我还希望更新包含 TD 错误术语,它应该是reward + discount * q_hat_next - q_hat.

作为参考,这是实现:

if done: # (terminal state reached)
   w += alpha*(reward - q_hat) * q_hat_grad
   break
else:
   next_action = policy(env, w, next_state, epsilon)
   q_hat_next = approx(w, next_state, next_action)
   w += alpha*(reward - discount*q_hat_next)*q_hat_grad
   state = next_state

这是来自强化学习的伪代码:简介(Sutton & Barto 撰写)(第 171 页):

在此处输入图像描述

由于实现是TD(0),n所以是1。那么伪代码中的更新可以简化为:

w <- w + a[G - v(S_t,w)] * dv(S_t,w)

变为(通过替换G == reward + discount*v(S_t+1,w))

w <- w + a[reward + discount*v(S_t+1,w) - v(S_t,w)] * dv(S_t,w)

或者使用原始代码示例中的变量名称:

w += alpha * (reward + discount * q_hat_next - q_hat) * q_hat_grad

我最终得到了与您相同的更新公式。看起来像非终端状态更新中的错误。

只有最终情况(如果done为真)应该是正确的,因为q_hat_next根据定义,then 始终为 0,因为情节结束并且无法获得更多奖励。

于 2018-08-26T22:21:00.683 回答