snakeq/qlearning.py

#!/usr/bin/env python
import os
import random
import statistics
import numpy as np
# Import snake game
from snake import Snake
class QTable:
    """Q-table for the snake game.

    State features (12 booleans, so 2^12 = 4096 states):
      - Is the snake moving up / right / down / left?
      - Is the apple above / to the right of / below / to the left of the head?
      - Is there an obstacle above / to the right of / below / to the left of the head?

    Actions: 4 (up, right, down, left), so the table has 4 * 2^12 = 16384 entries.

    Rewards: +1 for eating an apple, -10 for hitting an obstacle.
    """
    def __init__(self, file, save_every=5000):
        self.file = file
        self.save_every = save_every
        self.save_counter = 0
        if os.path.exists(file):
            # Resume from a previously saved table
            self.qtable = np.loadtxt(file)
        else:
            # Start from scratch: one row per state, one column per action
            self.qtable = np.zeros((2**12, 4))
            with open(file + "_generation", "w") as f:
                f.write("0")
    def isWall(self, h, game):
        """Return True when position h lies outside the grid."""
        return h[0] < 0 or h[1] < 0 or h[0] >= game.grid_width or h[1] >= game.grid_height
    def get_state(self, game):
        # First compute useful values
        h = game.snake[0]
        left = (h[0]-1, h[1])
        right = (h[0]+1, h[1])
        up = (h[0], h[1]-1)
        down = (h[0], h[1]+1)
        a = game.apple
        # Directions are encoded as clock positions: 12=up, 3=right, 6=down, 9=left
        snake_go_up = (game.direction == 12)
        snake_go_right = (game.direction == 3)
        snake_go_down = (game.direction == 6)
        snake_go_left = (game.direction == 9)
        apple_up = (a[1] < h[1])
        apple_right = (a[0] > h[0])
        apple_down = (a[1] > h[1])
        apple_left = (a[0] < h[0])
        obstacle_up = (up in game.snake or self.isWall(up, game))
        obstacle_right = (right in game.snake or self.isWall(right, game))
        obstacle_down = (down in game.snake or self.isWall(down, game))
        obstacle_left = (left in game.snake or self.isWall(left, game))
        # This encoding comes from me; I do not know if it is the best way to
        # identify a state: pack the 12 booleans into a 12-bit integer
        state = \
            2**11*snake_go_up + \
            2**10*snake_go_right + \
            2**9*snake_go_down + \
            2**8*snake_go_left + \
            2**7*apple_up + \
            2**6*apple_right + \
            2**5*apple_down + \
            2**4*apple_left + \
            2**3*obstacle_up + \
            2**2*obstacle_right + \
            2**1*obstacle_down + \
            obstacle_left
        return state
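
    # Worked example of the encoding: a snake moving right (2**10) with the
    # apple below the head (2**5) and a wall immediately to the left (2**0)
    # yields state 1024 + 32 + 1 = 1057.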
    def apply_bellman(self, state, action, new_state, reward):
        # Q-learning update: Q(s,a) += alpha*(reward + gamma*max_a' Q(s',a') - Q(s,a))
        alpha = 0.1   # learning rate
        gamma = 0.95  # discount factor
        self.qtable[state, action] = self.qtable[state, action] + alpha*(reward + gamma*np.max(self.qtable[new_state]) - self.qtable[state, action])
        # Periodically checkpoint the table and a generation counter to disk
        self.save_counter += 1
        if self.save_counter >= self.save_every:
            np.savetxt(self.file, self.qtable)
            if os.path.exists(self.file + "_generation"):
                generation = 0
                with open(self.file + "_generation", "r") as f:
                    generation = int(f.readline().rstrip())
                generation += self.save_every
                with open(self.file + "_generation", "w") as f:
                    f.write(str(generation))
                print("----------------------------- Checkpointing generation " + str(generation))
            self.save_counter = 0
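
    # Numeric example of the update: with Q(s,a)=0, reward=1 and
    # max Q(s') = 0.5, the new value is 0 + 0.1*(1 + 0.95*0.5 - 0) = 0.1475.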
    def get_action(self, state):
        # Explore with a random action until this state has a positive Q-value,
        # then exploit the best known action
        action = random.choice((0, 1, 2, 3))
        if np.max(self.qtable[state]) > 0:
            # Several actions may share the maximal Q-value; one could pick
            # randomly among them with np.flatnonzero, but argmax simply takes
            # the first
            action = np.argmax(self.qtable[state])
        return action

    def get_random_action(self):
        return random.choice((0, 1, 2, 3))
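
    # A common alternative exploration scheme is epsilon-greedy; a minimal
    # sketch, not used by the training loop below (the epsilon parameter is
    # my assumption):
    def get_action_epsilon_greedy(self, state, epsilon=0.1):
        if random.random() < epsilon:
            return self.get_random_action()
        return int(np.argmax(self.qtable[state]))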

# Perform learning
width, height = 80, 50  # My advice: start with a small grid (e.g. 5x5) so the
                        # agent gets many interactions quickly and early games
                        # do not drag on
perf = 0
perf_list = list()
last_state = None
last_action = None
game = Snake(length=1, fps=500, grid_pts=20,
             startat=(random.randint(0, width-1), random.randint(0, height-1)),
             grid_width=width, grid_height=height)
qtable = QTable("qtable.txt")
while True:
    result = 0
    stuck = 0
    stuck_tolerance = 1
    state = qtable.get_state(game)
    while result >= 0:
        action = qtable.get_action(state)
        result = game.play3(action)
        new_state = qtable.get_state(game)
        # Compute reward and reset the stuck counter on any event
        reward = 0
        if result == -1:   # hit an obstacle: game over
            reward = -10
            stuck = 0
        elif result == 1:  # ate an apple
            reward = 1
            stuck = 0
        # Agent is stuck: no apple eaten within roughly one grid's worth of moves
        if stuck >= (game.grid_width*game.grid_height)/stuck_tolerance:
            print("Stuck! Apply penalty and abort!")
            qtable.apply_bellman(state, action, new_state, -1)
            game.new_game()
            break
        # Apply learning
        qtable.apply_bellman(state, action, new_state, reward)
        state = new_state
        stuck += 1
    # Measurements
    score = game.last_score
    perf_list.append(score)
    perf = max(perf, score)
    print("Game ended with " + str(score) + ", best so far is " + str(perf) + ", median is " + str(statistics.median(perf_list)))