得到:
assert q_values.shape == (len(state_batch), self.nb_actions)
AssertionError
q_values.shape <class 'tuple'>: (1, 1, 10)
(len(state_batch), self.nb_actions) <class 'tuple'>: (1, 10)
来自 sarsa 代理的 keras-rl 库:
rl.agents.sarsa.SARSAAgent#compute_batch_q_values
batch = self.process_state_batch(state_batch)
q_values = self.model.predict_on_batch(batch)
assert q_values.shape == (len(state_batch), self.nb_actions)
这是我的代码:
class MyEnv(Env):
def __init__(self):
self._reset()
def _reset(self) -> None:
self.i = 0
def _get_obs(self) -> List[float]:
return [1] * 20
def reset(self) -> List[float]:
self._reset()
return self._get_obs()
model = Sequential()
model.add(Dense(units=20, activation='relu', input_shape=(1, 20)))
model.add(Dense(units=10, activation='softmax'))
logger.info(model.summary())
policy = BoltzmannQPolicy()
agent = SARSAAgent(model=model, nb_actions=10, policy=policy)
optimizer = Adam(lr=1e-3)
agent.compile(optimizer, metrics=['mae'])
env = MyEnv()
agent.fit(env, 1, verbose=2, visualize=True)
想知道是否有人可以向我解释应该如何设置尺寸以及它如何与库一起使用?我正在输入一个包含 20 个输入的列表,并希望输出为 10。