这是我的代码:一个简单的 DQN,用来学习玩贪吃蛇。我不知道为什么它在训练一段时间后就停止学习了。它学会了蛇头不应该撞墙,但没有学会吃水果,即使我在蛇靠近水果时给予正奖励、远离水果时给予负奖励(这是为了让蛇明白它应该朝水果移动)。但由于某种原因,分数永远不会超过 1 或 2: """ ######################################################### MAIN.py
# -*- coding: utf-8 -*-
"""
Created on Mon Aug 10 13:04:45 2020
@author: Ryan
"""
from dq_learning import Agent
import numpy as np
import tensorflow as tf
import snake
import sys
import pygame
import gym
if __name__ == '__main__':
    # Training driver: plays `n_games` episodes of snake with an epsilon-greedy
    # DQN agent, learning from replayed transitions after every step.
    observation_space = 31
    action_space = 4
    lr = 0.001
    n_games = 50000
    steps = 1000
    # snake.step() returns a list of (its second argument + 6) values, so the
    # grid size is derived from the network input width to keep both in sync.
    grid_cells = observation_space - 6

    agent = Agent(gamma=0.99, epsilon=1.0, lr=lr,
                  input_dims=observation_space,
                  n_actions=action_space,
                  batch_size=64)
    scores = []
    eps_history = []

    for i in range(n_games):
        score = 0
        # The snake module exposes no reset(); every game starts from an
        # all-zero observation until the first step() result arrives.
        observation = [0] * observation_space
        for _ in range(steps):
            # Keep the pygame window responsive and allow closing it.
            for evt in pygame.event.get():
                if evt.type == pygame.QUIT:
                    pygame.quit()
                    sys.exit()
            # Actions are integers in [0, n_actions): greedy w.r.t. the model
            # or uniformly random, depending on the current epsilon.
            action = agent.choose_action(observation)
            observation_, reward, done, info = snake.step(action, grid_cells)
            score += reward
            agent.store_transition(observation, action, reward, observation_, done)
            observation = observation_
            agent.learn()
            if done:
                break
        print("NEXT GAME")
        eps_history.append(agent.epsilon)
        scores.append(score)
        avg_score = np.mean(scores[-100:])
        print("episode: ", i, " scores %.2f" %score,
              "average score: %.2f" %avg_score,
              " epsilon %.2f" %agent.epsilon)
        print("last score: ", scores[-1])
#####################################
#DQ_LEARNING.PY
# -*- coding: utf-8 -*-
"""
Created on Tue Aug 4 12:23:14 2020
@author: Ryan
"""
import numpy as np
import tensorflow as tf
from tensorflow import keras
class ReplayBuffer:
    """Fixed-size circular store of (state, action, reward, next state, done) transitions."""

    def __init__(self, max_size, input_dims):
        """Allocate the buffer.

        max_size   -- number of transitions kept before the oldest are overwritten.
        input_dims -- shape of one observation: an int for flat observations,
                      or a tuple/list of ints for multi-dimensional ones.
        """
        self.mem_size = max_size
        self.mem_cntr = 0
        # Accept both a bare int and a shape tuple.
        state_shape = tuple(int(dim) for dim in np.atleast_1d(input_dims))
        self.state_memory = np.zeros((self.mem_size, *state_shape), dtype=np.float32)
        self.new_state_memory = np.zeros((self.mem_size, *state_shape), dtype=np.float32)
        self.action_memory = np.zeros(self.mem_size, np.int32)
        self.reward_memory = np.zeros(self.mem_size, np.float32)
        # Holds 1 - done, i.e. 1 for a non-terminal transition and 0 for a
        # terminal one, so it can be used directly as a mask on the next-state value.
        self.terminal_memory = np.zeros(self.mem_size, np.int32)

    def store_transitions(self, state, action, reward, state_, done):
        """Write one transition, overwriting the oldest slot once the buffer is full."""
        index = self.mem_cntr % self.mem_size
        self.state_memory[index] = state
        self.new_state_memory[index] = state_
        self.reward_memory[index] = reward
        self.action_memory[index] = action
        self.terminal_memory[index] = 1 - int(done)
        self.mem_cntr += 1

    def sample_buffer(self, batch_size):
        """Return `batch_size` distinct stored transitions chosen uniformly at random.

        Returns (states, actions, rewards, next_states, terminal), where
        `terminal` is the stored 1 - done mask. np.random.choice raises
        ValueError when fewer than `batch_size` transitions have been stored.
        """
        max_mem = min(self.mem_cntr, self.mem_size)
        batch = np.random.choice(max_mem, batch_size, replace=False)
        states = self.state_memory[batch]
        states_ = self.new_state_memory[batch]
        rewards = self.reward_memory[batch]
        actions = self.action_memory[batch]
        terminal = self.terminal_memory[batch]
        return states, actions, rewards, states_, terminal
def build_dqn(lr, n_actions, input_dims, fc1_dims, fc2_dims):
    """Build and compile the Q-network: two ReLU hidden layers and a linear output.

    lr         -- Adam learning rate.
    n_actions  -- width of the output layer (one Q-value per action).
    input_dims -- observation shape, as an int or a tuple of ints.
    fc1_dims, fc2_dims -- widths of the two hidden layers.

    Returns the compiled keras model (mean-squared-error loss).
    """
    # Declaring the input shape up front builds the weights immediately, so
    # the model can be inspected or saved before its first training call.
    input_shape = tuple(int(dim) for dim in np.atleast_1d(input_dims))
    model = keras.Sequential()
    model.add(keras.Input(shape=input_shape))
    model.add(keras.layers.Dense(fc1_dims, activation='relu'))
    model.add(keras.layers.Dense(fc2_dims, activation='relu'))
    # No activation: Q-values are unbounded regression targets.
    model.add(keras.layers.Dense(n_actions))
    opt = keras.optimizers.Adam(learning_rate=lr)
    model.compile(optimizer=opt, loss='mean_squared_error')
    return model
class Agent():
    """Epsilon-greedy DQN agent backed by a replay buffer and a single Q-network."""

    def __init__(self, lr, gamma, n_actions, epsilon, batch_size,
                 input_dims, epsilon_dec=1e-3, epsilon_end=1e-2,
                 mem_size=1e6, fname='dqn_model.h5'):
        """Create the agent.

        lr          -- learning rate handed to build_dqn.
        gamma       -- discount factor applied to the next-state value.
        n_actions   -- number of discrete actions.
        epsilon     -- initial exploration probability.
        batch_size  -- number of transitions sampled per learn() call.
        input_dims  -- observation size handed to the buffer and the network.
        epsilon_dec -- amount subtracted from epsilon after each learn() update.
        epsilon_end -- lower bound for epsilon.
        mem_size    -- replay buffer capacity (converted to int).
        fname       -- path used by save_model() / load_model().
        """
        self.action_space = list(range(n_actions))
        self.gamma = gamma
        self.epsilon = epsilon
        self.eps_min = epsilon_end
        self.eps_dec = epsilon_dec
        self.batch_size = batch_size
        self.model_file = fname
        self.memory = ReplayBuffer(int(mem_size), input_dims)
        self.q_eval = build_dqn(lr, n_actions, input_dims, 256, 256)

    def store_transition(self, state, action, reward, new_state, done):
        """Record one environment transition in the replay buffer."""
        self.memory.store_transitions(state, action, reward, new_state, done)

    def choose_action(self, observation):
        """Return a random action with probability epsilon, else the argmax-Q action."""
        if np.random.random() < self.epsilon:
            return np.random.choice(self.action_space)
        # The network expects a batch, so wrap the single observation.
        state = np.array([observation])
        q_values = self.q_eval.predict(state, verbose=0)
        return np.argmax(q_values)

    def _decay_epsilon(self):
        """Lower epsilon by eps_dec without ever dropping below eps_min."""
        self.epsilon = max(self.epsilon - self.eps_dec, self.eps_min)

    def learn(self):
        """Run one training step on a sampled batch, then decay epsilon.

        Does nothing until at least `batch_size` transitions are stored.
        """
        if self.memory.mem_cntr < self.batch_size:
            return
        # `not_done` is the buffer's 1 - done mask: 0 for terminal transitions.
        states, actions, rewards, states_, not_done = \
            self.memory.sample_buffer(self.batch_size)
        q_eval = self.q_eval.predict(states, verbose=0)
        # NOTE(review): the same network supplies the bootstrap target; a
        # separate, periodically synced target network is the usual DQN setup.
        q_next = self.q_eval.predict(states_, verbose=0)
        # Start from the current predictions so only the taken action's
        # Q-value contributes to the loss.
        q_target = np.copy(q_eval)
        batch_index = np.arange(self.batch_size, dtype=np.int32)
        q_target[batch_index, actions] = rewards + \
            self.gamma * np.max(q_next, axis=1) * not_done
        self.q_eval.train_on_batch(states, q_target)
        self._decay_epsilon()

    def save_model(self):
        """Save the Q-network to `self.model_file`."""
        self.q_eval.save(self.model_file)

    def load_model(self):
        """Replace the Q-network with the one stored at `self.model_file`."""
        self.q_eval = keras.models.load_model(self.model_file)
##########################################
# snake.py
# -*- coding: utf-8 -*-
"""
Created on Fri Sep 4 14:32:30 2020
@author: Ryan
"""
import pygame
import random
from math import sqrt
import time
class Snakehead:
    """Head of the snake: a rectangle that advances `speed` pixels per move() call."""

    # Integer action accepted by read_input() -> movement direction.
    _DIRECTIONS = {0: 'left', 1: 'right', 2: 'up', 3: 'down'}
    # Movement direction -> (dx, dy) unit offset; screen y grows downwards.
    _OFFSETS = {'left': (-1, 0), 'right': (1, 0), 'up': (0, -1), 'down': (0, 1)}

    def __init__(self, posx, posy, width, height):
        self.posx = posx
        self.posy = posy
        self.width = width
        self.height = height
        self.movement = 'null'  # no direction until the first valid action
        self.speed = 16
        self.gameover = False

    def draw(self, Display):
        """Draw the head as a black rectangle on the given pygame surface."""
        pygame.draw.rect(Display, [0, 0, 0], [self.posx, self.posy, self.width, self.height])

    def read_input(self, key):
        """Set the movement direction from action `key` (0 left, 1 right, 2 up, 3 down).

        Any other value leaves the current direction unchanged.
        """
        self.movement = self._DIRECTIONS.get(key, self.movement)
        print(self.movement)

    def get_pos(self):
        """Return the top-left corner as (x, y)."""
        return self.posx, self.posy

    def get_movement(self):
        """Return the current movement direction string."""
        return self.movement

    def restart(self, ScreenW, ScreenH):
        """Centre the head on a ScreenW x ScreenH screen; the direction is kept."""
        self.posx = ScreenW / 2 - self.width / 2
        self.posy = ScreenH / 2 - self.height / 2

    def move(self, SW, SH):
        """Advance one step in the current direction. SW and SH are not used."""
        dx, dy = self._OFFSETS.get(self.movement, (0, 0))
        self.posx += dx * self.speed
        self.posy += dy * self.speed
class Food:
    """A food item drawn as a red rectangle that can be respawned at a random grid cell."""

    def __init__(self, posx, posy, width, height):
        self.posx = posx
        self.posy = posy
        self.width = width
        self.height = height
        self.red = random.randint(155, 255)

    def draw(self, Display):
        """Draw the food on the given pygame surface using its current shade of red."""
        pygame.draw.rect(Display, [self.red, 0, 0], [self.posx, self.posy, self.width, self.height])

    def get_pos(self):
        """Return the top-left corner as (x, y)."""
        return self.posx, self.posy

    def respawn(self, ScreenW, ScreenH):
        """Move the food to a random cell of the screen grid and pick a new shade.

        The cell size is the food's own width/height. Column and row 0 are
        never chosen (the lower bound is 1).
        """
        # random.randint needs integer bounds: a float bound raises TypeError
        # on Python 3.12+, so use floor division and an explicit int().
        last_col = int((ScreenW - self.width) // self.width)
        last_row = int((ScreenH - self.height) // self.height)
        self.posx = random.randint(1, last_col) * self.width
        self.posy = random.randint(1, last_row) * self.height
        self.red = random.randint(155, 255)
class Tail:
    """One body segment of the snake, drawn in a random colour chosen at creation."""

    def __init__(self, posx, posy, width, height):
        self.width = width
        self.height = height
        self.posx = posx
        self.posy = posy
        self.RGB = [random.randint(0, 255) for _ in range(3)]

    def draw(self, Diplay):
        """Draw the segment on the given pygame surface."""
        # Use the segment's own size, as Snakehead.draw and Food.draw do,
        # instead of a hard-coded 16x16.
        pygame.draw.rect(Diplay, self.RGB, [self.posx, self.posy, self.width, self.height])

    def move(self, px, py):
        """Place the segment's top-left corner at (px, py)."""
        self.posx = px
        self.posy = py

    def get_pos(self):
        """Return the top-left corner as (x, y)."""
        return self.posx, self.posy
# --- Module-level game state shared by step() ---
ScreenW = 720
ScreenH = 720
sheadX = 0
sheadY = 0
fX = 0
fY = 0
counter = 0  # food eaten in the current game
pygame.init()
pygame.display.set_caption("Snake Game")
Display = pygame.display.set_mode([ScreenW, ScreenH])
Display.fill([255, 255, 255])  # RGB white
black = [0, 0, 0]
font = pygame.font.SysFont(None, 30)
score = font.render("Score: 0", True, black)
shead = Snakehead(ScreenW / 2 - 16/2, ScreenH / 2 - 16/2, 16, 16)
# random.randint needs integer bounds (a float raises TypeError on Python
# 3.12+), hence the floor division. The food is placed on the same 16px grid
# that Food.respawn() uses, so it never starts half a cell off or at x = -8.
f = Food(random.randint(0, (ScreenW - 16) // 16) * 16,
         random.randint(0, (ScreenH - 16) // 16) * 16, 16, 16)
tails = []
Fps = 60
timer_clock = pygame.time.Clock()
previous_distance = 0  # head-to-food distance before the latest step
d = 0  # head-to-food distance after the latest step
def _reset_game():
    """Recentre the head, drop the tail, zero the score and flag the game as over."""
    global score, counter
    shead.restart(ScreenW, ScreenH)
    tails.clear()
    counter = 0
    score = font.render("Score: 0", True, black)
    shead.gameover = True


def step(action, observation_space):
    """Advance the game by one frame and return a gym-style transition.

    action            -- 0 left, 1 right, 2 up, 3 down (see Snakehead.read_input).
    observation_space -- number of cells in the square tail-occupancy grid
                         centred on the head; int(sqrt(observation_space))
                         is used as the side length.

    Returns (observation, reward, done, 0). `observation` is a list of
    observation_space + 6 values:
      [0], [1]  absolute x / y head-to-food offset (fractions of the screen)
      [2]       Euclidean head-to-food distance in the same units
      [3]       always 0
      [4], [5]  head x / y as fractions of the screen
      [6:]      1 where a tail segment occupies the grid cell, else 0
    `reward` is 100 for eating, -300 for biting the tail and -200 for leaving
    the screen. On every other step it is the distance-shaping term (+10 when
    the head got closer to the food, -100 otherwise) plus 10 per food already
    eaten. `done` is True when the snake died during this call; the game has
    then already been reset.
    """
    global score, counter, previous_distance, d
    shead.gameover = False
    observation_ = [0] * (observation_space + 6)
    reward = 0
    Display.fill([255, 255, 255])
    shead.read_input(action)

    head_x, head_y = shead.get_pos()
    food_x, food_y = f.get_pos()

    # Food pickup: axis-aligned overlap of the two 16x16 rectangles.
    if (head_x + 16 > food_x and head_x < food_x + 16
            and head_y + 16 > food_y and head_y < food_y + 16):
        f.respawn(ScreenW, ScreenH)
        counter += 1
        score = font.render("Score: " + str(counter), True, black)
        # A new segment starts on top of the last one (or on the head) and
        # separates from it as the body advances.
        grow_x, grow_y = tails[-1].get_pos() if tails else (head_x, head_y)
        tails.append(Tail(grow_x, grow_y, 16, 16))
        reward = 100
        print(tails)

    # Self collision. The first two segments are skipped, as before. Stop at
    # the first hit: the reset empties `tails`, so continuing by index would
    # run past the end of the list.
    for segment in tails[2:]:
        if segment.get_pos() == (head_x, head_y):
            print("collision")
            _reset_game()
            reward = -300
            print("lost-3")
            break

    # Wall collision, checked on the (possibly just reset) head position.
    head_x, head_y = shead.get_pos()
    if head_x < 0 or head_x + 16 > ScreenW:
        _reset_game()
        print("lost-1")
        reward = -200
    elif head_y < 0 or head_y + 16 > ScreenH:
        _reset_game()
        reward = -200
        print("lost-2")

    # Body follows the head: walk from the tip forwards so each segment takes
    # the position its predecessor is about to leave.
    for idx in range(len(tails) - 1, 0, -1):
        tails[idx].move(*tails[idx - 1].get_pos())
    if tails:
        tails[0].move(*shead.get_pos())
    shead.move(ScreenW, ScreenH)

    # Render the frame.
    shead.draw(Display)
    Display.blit(score, (10, 10))
    for tail in tails:
        tail.draw(Display)
    f.draw(Display)
    pygame.display.flip()
    timer_clock.tick(Fps)

    # --- Observation ---
    done = shead.gameover
    hx, hy = shead.get_pos()
    hx /= ScreenW
    hy /= ScreenH
    fx, fy = f.get_pos()
    fx /= ScreenW
    fy /= ScreenH
    observation_[0] = abs(hx - fx)
    observation_[1] = abs(hy - fy)
    previous_distance = d
    d = sqrt((fx - hx)**2 + (fy - hy)**2)
    observation_[2] = d
    observation_[3] = 0
    observation_[4] = hx
    observation_[5] = hy

    # Tail occupancy in an l x l window of 16px cells centred on the head.
    occupied = set()
    for t in tails:
        tx, ty = t.get_pos()
        occupied.add((int(tx / 16), int(ty / 16)))
    l = int(sqrt(observation_space))
    m = (l - 1) / 2
    startX, startY = shead.get_pos()
    col0 = int(startX / 16) - m
    row0 = int(startY / 16) - m
    c = 6
    for x in range(l):
        for y in range(l):
            observation_[c] = 1 if (col0 + x, row0 + y) in occupied else 0
            c += 1

    # --- Reward ---
    print("reward: ", reward)
    print("c_reward: ", counter*10)
    d_reward = 10 if d < previous_distance else - 100
    print("d_reward: ", d_reward)
    # Return the shaped reward that was previously only printed. Shaping is
    # applied on ordinary steps only: after eating or dying, the food or the
    # head has been relocated, so the distance change says nothing about the
    # action and would otherwise cancel the +100 / distort the penalty.
    total_reward = reward if reward != 0 else d_reward + counter*10
    print(observation_, total_reward, done, 0)
    return observation_, total_reward, done, 0