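I've recently started using Keras-RL to train an agent on a tic-tac-toe game. I'm practicing building custom environments for my final-year project, which involves doing the same thing for a larger, proper game.
At the following step I get an error. I tried googling it, but all the answers I found were situation-specific (or maybe I'm just bad at googling):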
dqn = build_agent(model, actions)
dqn.compile(Adam(lr=1e-3), metrics=["mae"])
dqn.fit(env, nb_steps=50000, visualize=False, verbose=1)
I'm using the following to build the model:
env = TTTEnv()
states = env.observation_space.shape
actions = env.action_space.n
def build_model(states, actions):
    model = Sequential()
    model.add(Dense(24, activation="relu", input_shape=states))
    model.add(Dense(24, activation="relu"))
    model.add(Dense(actions, activation="linear"))
    return model

def build_agent(model, actions):
    policy = BoltzmannQPolicy()
    memory = SequentialMemory(limit=50000, window_length=1)
    dqn = DQNAgent(model=model, memory=memory, policy=policy,
                   nb_actions=actions, nb_steps_warmup=10,
                   target_model_update=1e-2)
    return dqn
This is the custom environment that interacts with the tic-tac-toe game I wrote myself:
class TTTEnv(Env):
    def __init__(self):
        self.action_space = Discrete(9)
        # Caused problems with keras-rl so I resorted to flattening it.
        #self.observation_space = np.array([[Discrete(3)]*3,[Discrete(3)]*3,[Discrete(3)]*3])
        self.observation_space = np.array([Discrete(3)]*9)
        self.game = Game()
        self.state = self.game.gameArray.flatten()

    def step(self, action):
        reward = 0
        done = False
        self.game.printGame()
        position = self.game.inputs[action]
        if self.game.gameArray[position[0],position[1]] != 0:
            reward -= 20
            done = True
        else:
            self.game.gameArray[position[0],position[1]] = 1
            gameOver, winner = self.game.checkWinGYM()
            if winner == "win":
                reward += 50
                done = gameOver
            elif winner == "draw":
                reward += 10
            elif winner == "ingame":
                self.game.handleBotTurn()
                gameOver, winner = self.game.checkWinGYM()
                if winner == "loss":
                    done = gameOver
                    reward -= 50
                elif winner == "draw":
                    done = gameOver
                    reward += 10
        info = {}
        return self.game.gameArray.flatten(), reward, done, info

    def render(self):
        pass

    def reset(self):
        self.state = np.array([[0,0,0],[0,0,0],[0,0,0]])
        self.game.resetGameArray()
        return self.state
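In the environment above, step() returns self.game.gameArray.flatten(), while reset() returns a 3x3 array, so the two methods hand back observations of different shapes. The following is a minimal sketch of how that could be checked and made consistent. BoardStub and observation_shapes are hypothetical names that are not part of the original code, BoardStub only stands in for TTTEnv and Game, and the divmod mapping from an action index to a cell is an assumption.

```python
import numpy as np


class BoardStub:
    """Hypothetical stand-in for TTTEnv; only the observation shapes matter here."""

    def __init__(self):
        self.board = np.zeros((3, 3), dtype=int)

    def reset(self):
        self.board[:] = 0
        # Return the flat (9,) layout, the same layout step() returns.
        return self.board.flatten()

    def step(self, action):
        # Assumed mapping from an action index (0..8) to a board cell.
        row, col = divmod(action, 3)
        self.board[row, col] = 1
        return self.board.flatten(), 0, False, {}


def observation_shapes(env, action=0):
    """Return the shapes of the observations produced by reset() and step()."""
    first = env.reset()
    second, _reward, _done, _info = env.step(action)
    return np.shape(first), np.shape(second)


print(observation_shapes(BoardStub()))  # ((9,), (9,))
```

Running the same helper against the posted TTTEnv would presumably report (3, 3) for reset() and (9,) for step(), assuming gameArray is a 3x3 NumPy array.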
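I know my code isn't the cleanest, so please bear with me. I'm just trying to throw something together quickly so I can get to my real goal, my final project. If you'd like to see more code, let me know and I'll add it.
Thanks!
Edit: adding the error:
"ValueError: Error when checking input: expected dense_9_input to have 2 dimensions, but got array with shape (1, 1, 3, 3)"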
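One possible reading of the (1, 1, 3, 3) shape, assuming keras-rl hands observations to the model as (batch, window_length, *observation_shape): with window_length=1 and a 3x3 observation coming from reset(), that layout gives exactly (1, 1, 3, 3), while the first Dense layer was built with input_shape=states and so expects only (batch, features). The sketch below reproduces the shape arithmetic with NumPy alone; batch_for_model is a hypothetical helper written for illustration, not a keras-rl function.

```python
import numpy as np

WINDOW_LENGTH = 1  # matches SequentialMemory(window_length=1) in the question


def batch_for_model(observation, window_length=WINDOW_LENGTH):
    """Mimic the assumed (batch, window_length, *obs_shape) input layout."""
    # Stack the most recent observations: (window_length, *obs_shape).
    window = np.stack([observation] * window_length)
    # Add a leading batch axis of size 1: (1, window_length, *obs_shape).
    return window[np.newaxis, ...]


board = np.zeros((3, 3), dtype=int)            # what reset() returns in the question
print(batch_for_model(board).shape)            # (1, 1, 3, 3)
print(batch_for_model(board.flatten()).shape)  # (1, 1, 9)
```

If this reading is right, two changes would likely be needed together: reset() returning the same flattened array that step() returns, and the network starting with a Flatten layer whose input_shape is (window_length,) + states (here (1, 9)), which is how the keras-rl example scripts begin their models.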