#!/usr/bin/env python
import os
import random
import statistics
import numpy as np
# Import snake game
from snake import Snake
class QTable:
"""
# Boolean features:
# Snake go up?
# Snake go right?
# Snake go down?
# Snake go left?
# Apple at up?
# Apple at right?
# Apple at down?
# Apple at left?
# Obstacle at up?
# Obstacle at right?
# Obstacle at down?
# Obstacle at left?
# Tail in front?
##### Totally 13 boolean features so 2^13=8192 states
##### Totally 4 actions for the AI (up, right,down,left)
##### Totally 4*2^13 thus 32768 table entries
##### Reward +1 when eat an apple
##### Reward -10 when hit obstacle
"""
def __init__(self, file, save_every=5000):
self.file=file
self.save_every=save_every
self.save_counter=0
if os.path.exists(file):
self.qtable=np.loadtxt(file)
else:
self.qtable=np.zeros((2**13, 4))
with open(file+"_generation","w") as f:
f.write("0")
    def isWall(self,h,game):
        # True if position h lies outside the grid
        return(h[0]<0 or h[1]<0 or h[0]>=game.grid_width or h[1]>=game.grid_height)
    def get_state(self,game):
        # First compute useful values
        h=game.snake[0]
        left=(h[0]-1,h[1])
        right=(h[0]+1,h[1])
        up=(h[0],h[1]-1)
        down=(h[0],h[1]+1)
        a=game.apple
        # Directions are encoded as clock positions: 12=up, 3=right, 6=down, 9=left
        snake_go_up=(game.direction==12)
        snake_go_right=(game.direction==3)
        snake_go_down=(game.direction==6)
        snake_go_left=(game.direction==9)
apple_up=(a[1]<h[1])
apple_right=(a[0]>h[0])
apple_down=(a[1]>h[1])
apple_left=(a[0]<h[0])
obstacle_up=(up in game.snake or self.isWall(up, game))
obstacle_right=(right in game.snake or self.isWall(right, game))
obstacle_down=(down in game.snake or self.isWall(down, game))
obstacle_left=(left in game.snake or self.isWall(left, game))
tail_in_front=0
if snake_go_right:
for x in range(h[0],game.grid_width):
if (x,h[1]) in game.snake[1:]:
tail_in_front=1
break
elif snake_go_left:
for x in range(0,h[0]):
if (x,h[1]) in game.snake[1:]:
tail_in_front=1
break
elif snake_go_up:
for y in range(0,h[1]):
if (h[0],y) in game.snake[1:]:
tail_in_front=1
break
elif snake_go_down:
for y in range(h[1],game.grid_height):
if (h[0],y) in game.snake[1:]:
tail_in_front=1
break
        # This encoding is my own idea; I do not know if it is the best way to identify a state
state=\
2**12*tail_in_front+\
2**11*snake_go_up+\
2**10*snake_go_right+\
2**9*snake_go_down+\
2**8*snake_go_left+\
2**7*apple_up+\
2**6*apple_right+\
2**5*apple_down+\
2**4*apple_left+\
2**3*obstacle_up+\
2**2*obstacle_right+\
2**1*obstacle_down+\
obstacle_left
return(state)
    def apply_bellman(self,state,action,new_state,reward):
        alpha=0.5 # Learning rate
        gamma=0.9 # Discount factor
        # Q-learning update: Q(s,a) += alpha*(reward + gamma*max_a' Q(s',a') - Q(s,a))
        self.qtable[state,action]=self.qtable[state,action]+alpha*(reward+gamma*np.max(self.qtable[new_state])-self.qtable[state,action])
        self.save_counter+=1
        # Periodically checkpoint the Q-table and its generation counter to disk
        if self.save_counter>=self.save_every:
            np.savetxt(self.file,self.qtable)
if os.path.exists(self.file+"_generation"):
generation=0
with open(self.file+"_generation","r") as f:
generation=int(f.readline().rstrip())
generation+=self.save_every
with open(self.file+"_generation","w") as f:
f.write(str(generation))
print("----------------------------- Checkpointing generation "+str(generation))
self.save_counter=0
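    # Numeric illustration of the update in apply_bellman() above (the numbers
    # are made up): with alpha=0.5 and gamma=0.9 as above, starting from
    # Q[s,a]=0 with reward=1 and max_a' Q[s',a']=2, the update gives
    # 0 + 0.5*(1 + 0.9*2 - 0) = 1.4.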
    def get_action(self,state):
        # Explore with a random action by default; exploit the best known
        # action as soon as some action has a positive Q-value
        action=random.choice((0,1,2,3))
        if np.max(self.qtable[state]) > 0:
            # Note: np.argmax picks the first action when several share the
            # maximum Q-value; choosing randomly among the ties with
            # np.flatnonzero would also be possible
            action=np.argmax(self.qtable[state])
        return(action)
def get_random_action(self):
return(random.choice((0,1,2,3)))
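    # A common alternative to the exploit-or-random rule in get_action() above
    # is epsilon-greedy exploration: explore with a small fixed probability
    # regardless of the Q-values. Minimal sketch, not wired into the training
    # loop below; epsilon=0.1 is an assumed value, not from the original.
    def get_action_epsilon_greedy(self,state,epsilon=0.1):
        if random.random()<epsilon:
            return(random.choice((0,1,2,3))) # Explore with probability epsilon
        return(np.argmax(self.qtable[state])) # Otherwise exploit the best known action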
# Perform learning
width,height=40,30 # My advice: start with a small grid (e.g. 5x5) to get many interactions quickly and avoid an early toy effect
perf=0
perf_list=list()
game=Snake(length=1,fps=500,startat=(random.randint(0,width-1),random.randint(0,height-1)),grid_width=width,grid_height=height)
qtable=QTable("qtable.txt")
while True:
    result=0
    stuck=0 # Steps survived without eating an apple
    stuck_tolerance=1 # Higher values lower the stuck-detection threshold below
    stuck_count=0
    state=qtable.get_state(game)
while result >= 0:
action=qtable.get_action(state)
result=game.play3(action)
new_state=qtable.get_state(game)
# Compute reward and update stuck
reward=0
if result==-1:
reward=-10
stuck=0
stuck_count=0
elif result==1:
reward=1
stuck=0
stuck_count=0
        # Agent is stuck: it has survived too long without eating an apple
        if stuck>=(game.grid_width*game.grid_height)/stuck_tolerance:
            stuck=0
            stuck_count+=1
            game.new_apple()
            print("Stuck! Trying with a new apple...")
            if stuck_count>2:
                print("Can't get unstuck. Aborting this game!")
                stuck_count=0
                game.new_game()
                break
# Apply learning
qtable.apply_bellman(state,action,new_state,reward)
state=new_state
stuck+=1
    # Measurements
    score=game.last_score
    perf_list.append(score)
    perf=max(perf,score)
    print("Game ended with "+str(score)+", best so far is "+str(perf)+", median is "+str(statistics.median(perf_list)))