Minor changes
This commit is contained in:
parent
59ed0cdf73
commit
85a180809d
3 changed files with 8380 additions and 160 deletions
169
qlearning.py
169
qlearning.py
|
@ -5,47 +5,45 @@ import numpy as np
|
|||
# Import snake game
|
||||
from snake import Snake
|
||||
|
||||
class QTable:
|
||||
"""
|
||||
# Boolean features:
|
||||
# Snake go up?
|
||||
# Snake go right?
|
||||
# Snake go down?
|
||||
# Snake go left?
|
||||
# Apple at up?
|
||||
# Apple at right?
|
||||
# Apple at down?
|
||||
# Apple at left?
|
||||
# Obstacle at up?
|
||||
# Obstacle at right?
|
||||
# Obstacle at down?
|
||||
# Obstacle at left?
|
||||
# Tail in front?
|
||||
##### Totally 13 boolean features so 2^13=8192 states
|
||||
##### Totally 4 actions for the AI (up, right,down,left)
|
||||
##### Totally 4*2^13 thus 32768 table entries
|
||||
##### Reward +1 when eat an apple
|
||||
##### Reward -10 when hit obstacle
|
||||
"""
|
||||
def __init__(self, file, save_every=10):
|
||||
self.file=file
|
||||
self.save_every=save_every
|
||||
self.update_counter=0
|
||||
if os.path.exists(file):
|
||||
self.qtable=np.loadtxt(file)
|
||||
else:
|
||||
self.qtable=np.zeros((2**13, 4))
|
||||
|
||||
|
||||
# Setup QTable
|
||||
# Boolean features:
|
||||
# Snake go up?
|
||||
# Snake go right?
|
||||
# Snake go down?
|
||||
# Snake go left?
|
||||
# Apple at up?
|
||||
# Apple at right?
|
||||
# Apple at down?
|
||||
# Apple at left?
|
||||
# Obstacle at up?
|
||||
# Obstacle at right?
|
||||
# Obstacle at down?
|
||||
# Obstacle at left?
|
||||
# Queue in front?
|
||||
##### Totally 13 boolean features so 2^13=8192 states
|
||||
##### Totally 4 actions for the AI (up, right,down,left)
|
||||
##### Totally 4*2^13 thus 32768 table entries
|
||||
##### Reward +1 when eat an apple
|
||||
##### Reward -10 when hit obstacle
|
||||
|
||||
qtable=np.zeros((2**13, 4))
|
||||
|
||||
|
||||
|
||||
game=Snake(length=1,fps=200,startat=(10,10))
|
||||
|
||||
def isWall(h,game):
|
||||
def isWall(self,h,game):
    """
    Tell whether the cell h lies outside the playing grid.

    h is indexed as h[0] (compared against game.grid_width) and h[1]
    (compared against game.grid_height).
    Returns True when the cell is off the grid, False otherwise.
    """
    inside=(0 <= h[0] < game.grid_width) and (0 <= h[1] < game.grid_height)
    return(not inside)
|
||||
|
||||
|
||||
last_state=None
|
||||
last_action=None
|
||||
attempt=0
|
||||
def event_handler(game,event):
|
||||
global last_state,last_action,attempt
|
||||
|
||||
def get_state(self,game):
|
||||
# First compute useful values
|
||||
h=game.snake[0]
|
||||
left=(h[0]-1,h[1])
|
||||
right=(h[0]+1,h[1])
|
||||
|
@ -63,85 +61,82 @@ def event_handler(game,event):
|
|||
apple_down=(a[1]>h[1])
|
||||
apple_left=(a[0]<h[0])
|
||||
|
||||
obstacle_up=(up in game.snake or isWall(up, game))
|
||||
obstacle_right=(right in game.snake or isWall(right, game))
|
||||
obstacle_down=(down in game.snake or isWall(down, game))
|
||||
obstacle_left=(left in game.snake or isWall(left, game))
|
||||
obstacle_up=(up in game.snake or self.isWall(up, game))
|
||||
obstacle_right=(right in game.snake or self.isWall(right, game))
|
||||
obstacle_down=(down in game.snake or self.isWall(down, game))
|
||||
obstacle_left=(left in game.snake or self.isWall(left, game))
|
||||
|
||||
queue_in_front=0
|
||||
tail_in_front=0
|
||||
if game.direction == 3:
|
||||
for x in range(h[0],game.grid_width):
|
||||
if (x,h[1]) in game.snake[1:]:
|
||||
queue_in_front=1
|
||||
tail_in_front=1
|
||||
break
|
||||
elif game.direction == 9:
|
||||
for x in range(0,h[0]):
|
||||
if (x,h[1]) in game.snake[1:]:
|
||||
queue_in_front=1
|
||||
tail_in_front=1
|
||||
break
|
||||
elif game.direction == 12:
|
||||
for y in range(0,h[1]):
|
||||
if (h[0],y) in game.snake[1:]:
|
||||
queue_in_front=1
|
||||
tail_in_front=1
|
||||
break
|
||||
elif game.direction == 6:
|
||||
for y in range(h[1],game.grid_height):
|
||||
if (h[0],y) in game.snake[1:]:
|
||||
queue_in_front=1
|
||||
tail_in_front=1
|
||||
break
|
||||
|
||||
reward=0
|
||||
if event==0:
|
||||
attempt+=1
|
||||
if event==-1:
|
||||
reward=-10
|
||||
attempt=0
|
||||
elif event==1:
|
||||
reward=5
|
||||
attempt=0
|
||||
|
||||
# This encoding is my own; I do not know if it is the best way to identify a state
|
||||
state=2**12*queue_in_front+2**11*snake_go_up+2**10*snake_go_right+2**9*snake_go_down+2**8*snake_go_left+2**7*apple_up+2**6*apple_right+2**5*apple_down+2**4*apple_left+2**3*obstacle_up+2**2*obstacle_right+2**1*obstacle_down+obstacle_left
|
||||
state=2**12*tail_in_front+2**11*snake_go_up+2**10*snake_go_right+2**9*snake_go_down+2**8*snake_go_left+2**7*apple_up+2**6*apple_right+2**5*apple_down+2**4*apple_left+2**3*obstacle_up+2**2*obstacle_right+2**1*obstacle_down+obstacle_left
|
||||
return(state)
|
||||
|
||||
def apply_bellman(self,state,action,new_state,reward,alpha=0.5,gamma=0.9):
    """
    Apply one Q-learning (Bellman) update to the entry qtable[state,action].

    state, action: row and column of the entry to update
    new_state: row whose maximum Q value is used as the future value estimate
    reward: reward used for this update
    alpha: learning rate (defaults to 0.5, the value previously hard-coded)
    gamma: discount factor (defaults to 0.9, the value previously hard-coded)

    Every save_every updates the whole table is written to self.file
    with np.savetxt and the update counter is reset.
    """
    current=self.qtable[state,action]
    best_next=np.max(self.qtable[new_state])
    self.qtable[state,action]=current+alpha*(reward+gamma*best_next-current)
    # Count updates so the table is only flushed to disk periodically
    self.update_counter+=1
    if self.update_counter>=self.save_every:
        np.savetxt(self.file,self.qtable)
        self.update_counter=0
|
||||
|
||||
def get_action(self,state):
|
||||
# Choose an action
|
||||
action=random.choice((0,1,2,3))
|
||||
if np.max(qtable[state]) > 0:
|
||||
if np.max(self.qtable[state]) > 0:
|
||||
#qactions=qtable[state]
|
||||
#options=np.flatnonzero(qactions == np.max(qactions)) # Since Q value might be equals for several actions
|
||||
#action = random.choice(options)
|
||||
action=np.argmax(qtable[state])
|
||||
action=np.argmax(self.qtable[state])
|
||||
return(action)
|
||||
|
||||
# Avoid infinite loop
|
||||
if attempt>game.grid_height*game.grid_width:
|
||||
return(-1)
|
||||
|
||||
# Update current state Q
|
||||
if last_state != None:
|
||||
qtable[last_state,last_action]=qtable[last_state,last_action]+0.7*(reward+0.9*np.max(qtable[state])-qtable[last_state,last_action])
|
||||
|
||||
|
||||
|
||||
# Perform learning
|
||||
perf=0
|
||||
last_state=None
|
||||
last_action=None
|
||||
game=Snake(length=4,fps=300,startat=(10,10))
|
||||
qtable=QTable("qtable.txt")
|
||||
|
||||
for i in range(0,10000):
|
||||
result=0
|
||||
while result >= 0:
|
||||
state=qtable.get_state(game)
|
||||
action=qtable.get_action(state)
|
||||
result=game.play3(action)
|
||||
if last_state!=None:
|
||||
reward=0
|
||||
if result==-1:
|
||||
reward=-10
|
||||
elif result==1:
|
||||
reward=1
|
||||
qtable.apply_bellman(last_state,last_action,state,reward)
|
||||
last_state=state
|
||||
last_action=action
|
||||
|
||||
# Apply the action
|
||||
snake_action=12
|
||||
if action==1:
|
||||
snake_action=3
|
||||
elif action==2:
|
||||
snake_action=6
|
||||
elif action==3:
|
||||
snake_action=9
|
||||
game.direction=snake_action
|
||||
return(0)
|
||||
|
||||
if os.path.exists("qtable.txt"):
|
||||
qtable=np.loadtxt("qtable.txt")
|
||||
|
||||
perf=0
|
||||
for i in range(0,10000):
|
||||
last_state=None
|
||||
last_action=None
|
||||
score=game.run(event_handler=event_handler)
|
||||
attempt=0
|
||||
if i%10 == 0:
|
||||
np.savetxt('qtable.txt',qtable)
|
||||
# Measurements
|
||||
score=game.last_score
|
||||
perf=max(perf,score)
|
||||
print("Game ended with "+str(score)+" best so far is "+str(perf))
|
8192
qtable.txt
Normal file
8192
qtable.txt
Normal file
File diff suppressed because it is too large
Load diff
99
snake.py
99
snake.py
|
@ -7,7 +7,7 @@ class Snake:
|
|||
Programmable Game of Snake written in PyGame
|
||||
"""
|
||||
|
||||
def __init__(self, startat=(0,0), margin=80,length=4,grid_width=30,grid_height=30, grid_pts=30,fps=180):
|
||||
def __init__(self, startat=(0,0), margin=80,length=4,grid_width=30,grid_height=30, grid_pts=30,fps=15):
|
||||
# Init attributes
|
||||
self.grid_width=grid_width
|
||||
self.grid_height=grid_height
|
||||
|
@ -17,11 +17,16 @@ class Snake:
|
|||
self.attempt=0
|
||||
self.fps=fps
|
||||
self.startat=startat
|
||||
self.last_score=-1
|
||||
# Setup pygame
|
||||
pygame.init()
|
||||
self.font=pygame.font.SysFont(pygame.font.get_default_font(), int(self.margin/2))
|
||||
self.font_small=pygame.font.SysFont(pygame.font.get_default_font(), int(self.margin/2.5))
|
||||
self.screen=pygame.display.set_mode((grid_width*grid_pts,grid_height*grid_pts+margin))
|
||||
self.clock = pygame.time.Clock()
|
||||
# Start game
|
||||
self.new_game()
|
||||
self.draw()
|
||||
|
||||
def new_game(self):
|
||||
"""
|
||||
|
@ -117,32 +122,73 @@ class Snake:
|
|||
return(True)
|
||||
return(False)
|
||||
|
||||
|
||||
def run(self, event_handler=None):
|
||||
"""
|
||||
Play a game (one attempt)
|
||||
"""
|
||||
clock = pygame.time.Clock()
|
||||
ignore_has_loose=True
|
||||
self.new_game()
|
||||
last_event=0 # 0 is nothing, 1 is eat an apple and -1 loose
|
||||
while True:
|
||||
def draw(self):
|
||||
self.screen.fill((0,0,0))
|
||||
self.draw_snake()
|
||||
self.draw_apple()
|
||||
self.draw_infos()
|
||||
# Check for loose
|
||||
if not(ignore_has_loose) and self.has_loose():
|
||||
event_handler(self,-1)
|
||||
break
|
||||
else:
|
||||
ignore_has_loose=False
|
||||
pygame.display.flip()
|
||||
|
||||
def play(self,direction):
    """
    Advance the game by one step in the given wall clock direction
    (12=up, 3=right, 6=down and 9=left).

    Returns 1 when the snake head reaches the apple, -1 when has_loose()
    reports a loss (last_score is recorded and a new game is started)
    and 0 otherwise. The screen is redrawn and the clock ticked on every call.
    """
    # Move the snake
    self.direction=direction
    self.move()
    # Work out what happened during this step
    outcome=0
    ate_apple=(self.apple==self.snake[0])
    if ate_apple:
        # Grow by duplicating the last segment
        self.snake.append(self.snake[-1])
        self.new_apple()
        self.score+=1
        outcome=1
    elif self.has_loose():
        self.last_score=self.score
        self.new_game()
        outcome=-1
    # Redraw and throttle to the configured frame rate
    self.draw()
    self.clock.tick(self.fps)
    return(outcome)
|
||||
|
||||
def play2(self,direction):
    """
    Play using ascii directions ("up", "right", "down" or "left", case-insensitive)

    Returns the code returned by play(), or None when the direction name is unknown.
    """
    # Wall clock code expected by play() for each direction name
    clock_codes={"up":12,"right":3,"down":6,"left":9}
    code=clock_codes.get(direction.lower())
    if code is None:
        return(None)
    # play is a bound method: self must not be passed a second time
    return(self.play(code))
|
||||
|
||||
def play3(self,direction):
    """
    Play using 0123 directions (0=right, 1=down, 2=left and 3=up)

    Returns the code returned by play(), or None when direction is not 0, 1, 2 or 3.
    """
    # Position in the tuple is the 0123 code, value is the wall clock code for play()
    for index,clock_code in enumerate((3,6,9,12)):
        if direction == index:
            return(self.play(clock_code))
    return(None)
|
||||
|
||||
def play_with_keyboard(self):
|
||||
"""
|
||||
Play a game using keyboard
|
||||
"""
|
||||
while True:
|
||||
# Check inputs
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.QUIT:
|
||||
pygame.quit()
|
||||
sys.exit()
|
||||
elif event_handler==None and event.type == pygame.KEYDOWN:
|
||||
elif event.type == pygame.KEYDOWN:
|
||||
if event.key == pygame.K_LEFT and self.direction != 3:
|
||||
self.direction=9
|
||||
break
|
||||
|
@ -155,19 +201,6 @@ class Snake:
|
|||
elif event.key == pygame.K_DOWN and self.direction != 12:
|
||||
self.direction=6
|
||||
break
|
||||
# Check if an event handler is available
|
||||
if event_handler!=None:
|
||||
code=event_handler(self,last_event)
|
||||
if code < 0:
|
||||
|
||||
if self.play(self.direction) <0:
|
||||
break
|
||||
last_event=0
|
||||
self.move()
|
||||
# Check for eating apple
|
||||
if self.apple==self.snake[0]:
|
||||
self.snake.append(self.snake[len(self.snake)-1])
|
||||
self.new_apple()
|
||||
self.score+=1
|
||||
last_event=1
|
||||
pygame.display.flip()
|
||||
clock.tick(self.fps)
|
||||
return(self.score)
|
||||
|
|
Loading…
Add table
Reference in a new issue