我认为你是对的。我还希望更新包含 TD 错误术语,它应该是reward + discount * q_hat_next - q_hat
.
作为参考,这是实现:
if done: # (terminal state reached)
w += alpha*(reward - q_hat) * q_hat_grad
break
else:
next_action = policy(env, w, next_state, epsilon)
q_hat_next = approx(w, next_state, next_action)
w += alpha*(reward - discount*q_hat_next)*q_hat_grad
state = next_state
这是来自强化学习的伪代码:简介(Sutton & Barto 撰写)(第 171 页):
由于实现是TD(0),n
所以是1。那么伪代码中的更新可以简化为:
w <- w + a[G - v(S_t,w)] * dv(S_t,w)
变为(通过替换G == reward + discount*v(S_t+1,w))
)
w <- w + a[reward + discount*v(S_t+1,w) - v(S_t,w)] * dv(S_t,w)
或者使用原始代码示例中的变量名称:
w += alpha * (reward + discount * q_hat_next - q_hat) * q_hat_grad
我最终得到了与您相同的更新公式。看起来像非终端状态更新中的错误。
只有最终情况(如果done
为真)应该是正确的,因为q_hat_next
根据定义,then 始终为 0,因为情节结束并且无法获得更多奖励。