2022-11-01 14:28:00 +01:00
|
|
|
#!/usr/bin/env python
|
2022-11-01 20:21:34 +01:00
|
|
|
import sys,random,os
|
2022-11-01 14:28:00 +01:00
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
# Import snake game
|
|
|
|
from snake import Snake
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Q-table setup.
#
# A state is described by 12 boolean features observed around the head:
#   - current heading of the snake:       up, right, down, left
#   - apple position relative to head:    up, right, down, left
#   - obstacle (body or wall) next to it: up, right, down, left
# which gives 2**12 = 4096 distinct states.
#
# The agent chooses among 4 actions (up, right, down, left), so the table
# holds 4096 * 4 = 16384 Q-values, all starting at zero.
#
# Rewards: +1 for eating an apple, -10 for hitting an obstacle, and -1 when
# the snake wanders too long without eating.
qtable = np.zeros((2 ** 12, 4), dtype=float)
|
2022-11-01 14:28:00 +01:00
|
|
|
|
|
|
|
|
|
|
|
|
2022-11-01 20:21:34 +01:00
|
|
|
# Game instance trained by the loop at the bottom of the file
# (snake starts 4 cells long; fps=200).
game = Snake(fps=200, length=4)
|
2022-11-01 14:28:00 +01:00
|
|
|
|
2022-11-01 14:48:41 +01:00
|
|
|
def isWall(h, game):
    """Return True when grid cell ``h = (x, y)`` lies outside the playing field.

    The field covers ``0 <= x < game.grid_width`` and
    ``0 <= y < game.grid_height``; anything else counts as a wall.
    """
    return not (0 <= h[0] < game.grid_width and 0 <= h[1] < game.grid_height)
|
|
|
|
|
2022-11-01 17:30:25 +01:00
|
|
|
|
|
|
|
# Learning-loop memory shared with event_handler (declared global there):
# the previous (state, action) pair still awaiting its Q-update, and the
# number of consecutive plain moves since the last apple or collision.
last_state = last_action = None
attempt = 0
|
2022-11-01 15:09:08 +01:00
|
|
|
# Q-learning hyper-parameters and constants used by event_handler.
LEARNING_RATE = 0.7  # alpha: fraction of the TD error applied per update
DISCOUNT = 0.9       # gamma: weight of the best Q-value of the next state
STALL_LIMIT = 3000   # plain moves tolerated before a loop penalty is given

# game.direction codes for actions 0..3 (up, right, down, left). Both the
# heading features and the applied action rely on this order.
DIRECTIONS = (12, 3, 6, 9)


def _encode_state(game):
    """Pack the 12 boolean observations of *game* into an int in [0, 4096).

    From most to least significant bit:
      bits 11-8: snake heading up, right, down, left
      bits 7-4:  apple above, right of, below, left of the head
      bits 3-0:  obstacle (body cell or wall) in the cell up, right, down,
                 left of the head
    """
    head = game.snake[0]
    hx, hy = head[0], head[1]
    apple = game.apple
    # "Up" means a smaller y coordinate.
    neighbours = ((hx, hy - 1), (hx + 1, hy), (hx, hy + 1), (hx - 1, hy))
    features = [game.direction == code for code in DIRECTIONS]
    features += [apple[1] < hy, apple[0] > hx, apple[1] > hy, apple[0] < hx]
    features += [cell in game.snake or isWall(cell, game) for cell in neighbours]
    state = 0
    for flag in features:
        state = (state << 1) | int(bool(flag))
    return state


def _choose_action(q_values):
    """Return the index of a highest Q-value, breaking ties uniformly at random.

    An all-zero (unexplored) row yields a uniform random action. Once some
    actions have been penalised, they are not picked while a better-valued
    action exists.
    """
    best = np.flatnonzero(q_values == np.max(q_values))
    return int(random.choice(best))


def event_handler(game, event):
    """Observe the game, update the Q-table and steer the snake.

    Callback passed to ``game.run`` (presumably invoked once per move).
    *event* is 0 for a plain move, 1 when an apple was eaten (+1 reward) and
    -1 when the snake hit an obstacle (-10 reward). Updates the module-level
    ``qtable`` in place and sets ``game.direction`` to the chosen move.
    """
    global last_state, last_action, attempt

    state = _encode_state(game)

    # Reward earned by the previous action, i.e. the transition into `state`.
    reward = 0
    if event == 0:
        attempt += 1
    elif event == -1:
        reward = -10
        attempt = 0
    elif event == 1:
        reward = 1
        attempt = 0
    # Avoid infinite loops: punish wandering too long without eating.
    if attempt > STALL_LIMIT:
        reward = -1
        attempt = 0

    action = _choose_action(qtable[state])

    # Q-learning update for the previous (state, action) pair.
    if last_state is not None:
        if event == -1:
            # A collision presumably ends the game (the main loop starts a
            # fresh episode per game.run), so there is no future value to
            # bootstrap from.
            target = reward
        else:
            target = reward + DISCOUNT * np.max(qtable[state])
        old = qtable[last_state, last_action]
        qtable[last_state, last_action] = old + LEARNING_RATE * (target - old)

    last_state = state
    last_action = action

    game.direction = DIRECTIONS[action]
|
2022-11-01 14:28:00 +01:00
|
|
|
|
2022-11-01 20:21:34 +01:00
|
|
|
QTABLE_FILE = "qtable.txt"
N_GAMES = 10000
SAVE_EVERY = 100  # games between periodic checkpoints

# Resume learning from a previously saved Q-table when one exists.
if os.path.exists(QTABLE_FILE):
    qtable = np.loadtxt(QTABLE_FILE)

try:
    for i in range(N_GAMES):
        # Each game is a new episode: drop the previous game's pending
        # (state, action) pair so it is not updated across games.
        last_state = None
        last_action = None
        score = game.run(event_handler=event_handler)
        if i % SAVE_EVERY == 0:
            np.savetxt(QTABLE_FILE, qtable)
        print(f"Game ended with {score}")
finally:
    # Always persist the latest table. This keeps the games played since the
    # last checkpoint, including when the run is interrupted (e.g. Ctrl+C).
    np.savetxt(QTABLE_FILE, qtable)
|