#!/usr/bin/env python
import os
import random
import sys

import numpy as np
# Import snake game
from snake import Snake

class QTable:
"""
# Boolean features:
# Snake go up?
# Snake go right?
# Snake go down?
# Snake go left?
# Apple at up?
# Apple at right?
# Apple at down?
# Apple at left?
# Obstacle at up?
# Obstacle at right?
# Obstacle at down?
# Obstacle at left?
# Tail in front?
##### Totally 13 boolean features so 2^13=8192 states
##### Totally 4 actions for the AI (up, right,down,left)
##### Totally 4*2^13 thus 32768 table entries
##### Reward +1 when eat an apple
##### Reward -10 when hit obstacle
"""
    def __init__(self, file, save_every=5000):
        self.file = file
        self.save_every = save_every
        self.save_counter = 0
        if os.path.exists(file):
            self.qtable = np.loadtxt(file)
        else:
            self.qtable = np.zeros((2**13, 4))
            with open(file + "_generation", "w") as f:
                f.write("0")

    def isWall(self, h, game):
        # A cell is a wall if it lies outside the grid
        if h[0] < 0 or h[1] < 0 or h[0] >= game.grid_width or h[1] >= game.grid_height:
            return True
        return False

    def get_state(self, game):
        # First compute useful values
        h = game.snake[0]
        left = (h[0]-1, h[1])
        right = (h[0]+1, h[1])
        up = (h[0], h[1]-1)
        down = (h[0], h[1]+1)
        a = game.apple
        # Directions are encoded as clock positions: 12=up, 3=right, 6=down, 9=left
        snake_go_up = (game.direction == 12)
        snake_go_right = (game.direction == 3)
        snake_go_down = (game.direction == 6)
        snake_go_left = (game.direction == 9)
        apple_up = (a[1] < h[1])
        apple_right = (a[0] > h[0])
        apple_down = (a[1] > h[1])
        apple_left = (a[0] < h[0])
        obstacle_up = (up in game.snake or self.isWall(up, game))
        obstacle_right = (right in game.snake or self.isWall(right, game))
        obstacle_down = (down in game.snake or self.isWall(down, game))
        obstacle_left = (left in game.snake or self.isWall(left, game))
        tail_in_front = 0
        # Scan the row/column ahead of the head for any tail segment
        if snake_go_right:
            for x in range(h[0], game.grid_width):
                if (x, h[1]) in game.snake[1:]:
                    tail_in_front = 1
                    break
        elif snake_go_left:
            for x in range(0, h[0]):
                if (x, h[1]) in game.snake[1:]:
                    tail_in_front = 1
                    break
        elif snake_go_up:
            for y in range(0, h[1]):
                if (h[0], y) in game.snake[1:]:
                    tail_in_front = 1
                    break
        elif snake_go_down:
            for y in range(h[1], game.grid_height):
                if (h[0], y) in game.snake[1:]:
                    tail_in_front = 1
                    break
        # This encoding is my own; I do not know if it is the best way to identify a state
        state = (
            2**12*tail_in_front +
            2**11*snake_go_up +
            2**10*snake_go_right +
            2**9*snake_go_down +
            2**8*snake_go_left +
            2**7*apple_up +
            2**6*apple_right +
            2**5*apple_down +
            2**4*apple_left +
            2**3*obstacle_up +
            2**2*obstacle_right +
            2**1*obstacle_down +
            obstacle_left
        )
        return state

    def apply_bellman(self, state, action, new_state, reward):
        alpha = 0.5  # learning rate
        gamma = 0.9  # discount factor
        # Standard Q-learning update: Q(s,a) += alpha*(r + gamma*max_a' Q(s',a') - Q(s,a))
        self.qtable[state, action] = self.qtable[state, action] + alpha*(reward + gamma*np.max(self.qtable[new_state]) - self.qtable[state, action])
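        # Worked example (illustrative numbers): with Q[s,a]=0, reward=1 and
        # max(Q[s'])=2, the update gives 0 + 0.5*(1 + 0.9*2 - 0) = 1.4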
        self.save_counter += 1
        if self.save_counter >= self.save_every:
            np.savetxt(self.file, self.qtable)
            if os.path.exists(self.file + "_generation"):
                generation = 0
                with open(self.file + "_generation", "r") as f:
                    generation = int(f.readline().rstrip())
                generation += self.save_every
                with open(self.file + "_generation", "w") as f:
                    f.write(str(generation))
                print("Checkpointing generation " + str(generation))
            self.save_counter = 0
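        # For example, with save_every=5000 the generation file reads 5000,
        # 10000, 15000, ... after successive checkpoints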

    def get_action(self, state):
        # Choose an action: greedy once the state has a positive learned value,
        # random otherwise
        action = random.choice((0, 1, 2, 3))
        if np.max(self.qtable[state]) > 0:
            # Alternative: break ties randomly, since the Q value might be
            # equal for several actions
            # qactions = self.qtable[state]
            # options = np.flatnonzero(qactions == np.max(qactions))
            # action = random.choice(options)
            action = np.argmax(self.qtable[state])
        return action

    def get_random_action(self):
        return random.choice((0, 1, 2, 3))
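
    def get_epsilon_greedy_action(self, state, epsilon=0.1):
        # A sketch of a more conventional alternative to get_action, assuming
        # the same 0-3 action encoding; it is not used by the training loop
        # below. Epsilon-greedy keeps exploring with probability epsilon even
        # once a state has learned values.
        if random.random() < epsilon:
            return self.get_random_action()
        return int(np.argmax(self.qtable[state]))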

# Perform learning
width, height = 10, 10
perf = 0
last_state = None   # currently unused
last_action = None  # currently unused
game = Snake(length=1, fps=500,
             startat=(random.randint(0, width-1), random.randint(0, height-1)),
             grid_width=width, grid_height=height)
qtable = QTable("qtable.txt")
while True:
    result = 0
    stuck = 0            # moves since the last apple or death
    stuck_tolerance = 1  # stuck threshold = grid area / stuck_tolerance
    stuck_count = 0      # consecutive stuck events in the current game
    state = qtable.get_state(game)
    while result >= 0:
        action = qtable.get_action(state)
        result = game.play3(action)
        new_state = qtable.get_state(game)
        # Compute reward and update stuck
        reward = 0
        if result == -1:    # hit an obstacle
            reward = -10
            stuck = 0
            stuck_count = 0
        elif result == 1:   # ate an apple
            reward = 1
            stuck = 0
            stuck_count = 0
        # Agent is stuck: substitute a random action into the Q update below to
        # perturb the table and shake the agent out of its loop
        if stuck >= (game.grid_width*game.grid_height)/stuck_tolerance:
            stuck = 0
            stuck_count += 1
            action = qtable.get_random_action()
            print("Stuck! Trying a random action...")
            if stuck_count > 2:
                print("Can't get unstuck. Aborting!")
                stuck_count = 0
                game.new_game()
                break
        # Apply learning
        qtable.apply_bellman(state, action, new_state, reward)
        state = new_state
        stuck += 1
    # Measurements
    score = game.last_score
    perf = max(perf, score)
    print("Game ended with score " + str(score) + "; best so far is " + str(perf))