#!/usr/bin/env python
import sys, random, os, statistics

import numpy as np

# Import snake game
from snake import Snake


class QTable:
    """
    Boolean features:
      - Is the snake going up?
      - Is the snake going right?
      - Is the snake going down?
      - Is the snake going left?
      - Is the apple above?
      - Is the apple to the right?
      - Is the apple below?
      - Is the apple to the left?
      - Is there an obstacle above?
      - Is there an obstacle to the right?
      - Is there an obstacle below?
      - Is there an obstacle to the left?

    In total: 12 boolean features, so 2**12 = 4096 states.
    In total: 4 actions for the AI (up, right, down, left).
    In total: 4 * 2**12 = 16384 table entries.
    Reward +1 when the snake eats an apple.
    Reward -10 when the snake hits an obstacle.
    """

    def __init__(self, file, save_every=5000):
        self.file = file
        self.save_every = save_every
        self.save_counter = 0
        if os.path.exists(file):
            # Resume from a previously saved Q-table
            self.qtable = np.loadtxt(file)
        else:
            # Start from scratch: one row per state, one column per action
            self.qtable = np.zeros((2**12, 4))
            with open(file + "_generation", "w") as f:
                f.write("0")

    def isWall(self, h, game):
        # A cell is a wall if any coordinate is negative or beyond the grid dimensions.
        if h[0] < 0 or h[1] < 0 or h[0] >= game.grid_width or h[1] >= game.grid_height:
            return True
        return False

    def get_state(self, game):
        # First compute useful values: the head position, its four
        # neighbouring cells and the apple position.
        h = game.snake[0]
        left = (h[0] - 1, h[1])
        right = (h[0] + 1, h[1])
        up = (h[0], h[1] - 1)
        down = (h[0], h[1] + 1)
        a = game.apple

        # Current direction, encoded like clock positions
        # (12 = up, 3 = right, 6 = down, 9 = left).
        snake_go_up = (game.direction == 12)
        snake_go_right = (game.direction == 3)
        snake_go_down = (game.direction == 6)
        snake_go_left = (game.direction == 9)

        # Relative position of the apple with respect to the head.
        apple_up = (a[1] < h[1])
        apple_right = (a[0] > h[0])
        apple_down = (a[1] > h[1])
        apple_left = (a[0] < h[0])

        # An obstacle is either part of the snake's body or a wall.
        obstacle_up = (up in game.snake or self.isWall(up, game))
        obstacle_right = (right in game.snake or self.isWall(right, game))
        obstacle_down = (down in game.snake or self.isWall(down, game))
        obstacle_left = (left in game.snake or self.isWall(left, game))

        # This encoding is my own choice; I do not know whether it is the best
        # way to identify a state. Each boolean feature becomes one bit of a
        # 12-bit integer, which indexes a row of the Q-table.
        state = \
            2**11 * snake_go_up + \
            2**10 * snake_go_right + \
            2**9 * snake_go_down + \
            2**8 * snake_go_left + \
            2**7 * apple_up + \
            2**6 * apple_right + \
            2**5 * apple_down + \
            2**4 * apple_left + \
            2**3 * obstacle_up + \
            2**2 * obstacle_right + \
            2**1 * obstacle_down + \
            obstacle_left
        return state

    def apply_bellman(self, state, action, new_state, reward):
        # Standard Q-learning update:
        #   Q(s,a) <- Q(s,a) + alpha * (reward + gamma * max_a' Q(s',a') - Q(s,a))
        alpha = 0.1    # learning rate
        gamma = 0.95   # discount factor
        self.qtable[state, action] = self.qtable[state, action] + alpha * (reward + gamma * np.max(self.qtable[new_state]) - self.qtable[state, action])

        # Periodically checkpoint the Q-table and the number of updates
        # ("generations") performed so far.
        self.save_counter += 1
        if self.save_counter >= self.save_every:
            np.savetxt(self.file, self.qtable)
            if os.path.exists(self.file + "_generation"):
                generation = 0
                with open(self.file + "_generation", "r") as f:
                    generation = int(f.readline().rstrip())
                generation += self.save_every
                with open(self.file + "_generation", "w") as f:
                    f.write(str(generation))
                print("----------------------------- Checkpointing generation " + str(generation))
            self.save_counter = 0
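
    # A quick numeric illustration of the update above (hypothetical values,
    # not from a real run): with Q(s,a) = 0, reward = +1 (apple eaten) and
    # max_a' Q(s',a') = 2, the new value is
    #   Q(s,a) = 0 + 0.1 * (1 + 0.95 * 2 - 0) = 0.29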

    def get_action(self, state):
        # Choose an action: random by default, greedy as soon as the state
        # has at least one positive Q value.
        action = random.choice((0, 1, 2, 3))
        if np.max(self.qtable[state]) > 0:
            #qactions = self.qtable[state]
            #options = np.flatnonzero(qactions == np.max(qactions))  # Q values might be equal for several actions
            #action = random.choice(options)
            action = np.argmax(self.qtable[state])
        return action

    def get_random_action(self):
        return random.choice((0, 1, 2, 3))

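
# A tiny standalone illustration of the QTable logic (hypothetical values and
# file name, independent of the game itself):
#
#   q = QTable("scratch_qtable.txt")           # fresh table of zeros (file not present yet)
#   q.apply_bellman(1170, 1, 1170, reward=1)   # one update after "eating an apple"
#   q.get_action(1170)                         # -> 1, the only positive entry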

# Perform learning
width, height = 80, 50  # My advice is to start with a small grid (5x5) to have many interactions and avoid an early toy effect
perf = 0
perf_list = list()
last_state = None
last_action = None
game = Snake(length=1, fps=500, grid_pts=20, startat=(random.randint(0, width-1), random.randint(0, height-1)), grid_width=width, grid_height=height)
qtable = QTable("qtable.txt")
while True:
    result = 0
    stuck = 0
    stuck_tolerance = 1
    state = qtable.get_state(game)
    while result >= 0:
        action = qtable.get_action(state)
        result = game.play3(action)
        new_state = qtable.get_state(game)

        # Compute reward and update stuck
        reward = 0
        if result == -1:    # hit an obstacle
            reward = -10
            stuck = 0
        elif result == 1:   # ate an apple
            reward = 1
            stuck = 0

        # Agent is stuck: too many moves without eating an apple
        if stuck >= (game.grid_width * game.grid_height) / stuck_tolerance:
            print("Stuck! Apply penalty and abort!")
            qtable.apply_bellman(state, action, new_state, -1)
            game.new_game()
            break

        # Apply learning
        qtable.apply_bellman(state, action, new_state, reward)
        state = new_state
        stuck += 1

    # Measurements
    score = game.last_score
    perf_list.append(score)
    perf = max(perf, score)
    print("Game ended with " + str(score) + " best so far is " + str(perf) + " median is " + str(statistics.median(perf_list)))