Minor changes
This commit is contained in:
parent
3b35b6866d
commit
a1469f368b
3 changed files with 4971 additions and 843 deletions
51
qlearning.py
51
qlearning.py
|
@ -21,17 +21,18 @@ from snake import Snake
|
||||||
# Obstacle at right?
|
# Obstacle at right?
|
||||||
# Obstacle at down?
|
# Obstacle at down?
|
||||||
# Obstacle at left?
|
# Obstacle at left?
|
||||||
##### Totally 12 boolean features so 2^12=4096 states
|
# Snake body (queue) in front?
|
||||||
|
##### 13 boolean features in total, so 2^13=8192 states
|
||||||
##### 4 actions in total for the AI (up, right, down, left)
|
##### 4 actions in total for the AI (up, right, down, left)
|
||||||
##### Totally 4*2^12 thus 16 384 table entries
|
##### In total 4*2^13 = 32768 table entries
|
||||||
##### Reward +1 when eating an apple
|
##### Reward +1 when eating an apple
|
||||||
##### Reward -10 when hitting an obstacle
|
##### Reward -10 when hitting an obstacle
|
||||||
|
|
||||||
qtable=np.zeros((4096, 4))
|
qtable=np.zeros((2**13, 4))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
game=Snake(length=4,fps=200)
|
game=Snake(length=4,fps=200,startat=(10,10))
|
||||||
|
|
||||||
def isWall(h,game):
|
def isWall(h,game):
|
||||||
if h[0]<0 or h[1]<0 or h[0] >= game.grid_width or h[1] >= game.grid_height:
|
if h[0]<0 or h[1]<0 or h[0] >= game.grid_width or h[1] >= game.grid_height:
|
||||||
|
@ -67,6 +68,28 @@ def event_handler(game,event):
|
||||||
obstacle_down=(down in game.snake or isWall(down, game))
|
obstacle_down=(down in game.snake or isWall(down, game))
|
||||||
obstacle_left=(left in game.snake or isWall(left, game))
|
obstacle_left=(left in game.snake or isWall(left, game))
|
||||||
|
|
||||||
|
queue_in_front=0
|
||||||
|
if game.direction == 3:
|
||||||
|
for x in range(h[0],game.grid_width):
|
||||||
|
if (x,h[1]) in game.snake[1:]:
|
||||||
|
queue_in_front=1
|
||||||
|
break
|
||||||
|
elif game.direction == 9:
|
||||||
|
for x in range(0,h[0]):
|
||||||
|
if (x,h[1]) in game.snake[1:]:
|
||||||
|
queue_in_front=1
|
||||||
|
break
|
||||||
|
elif game.direction == 12:
|
||||||
|
for y in range(0,h[1]):
|
||||||
|
if (h[0],y) in game.snake[1:]:
|
||||||
|
queue_in_front=1
|
||||||
|
break
|
||||||
|
elif game.direction == 6:
|
||||||
|
for y in range(h[1],game.grid_height):
|
||||||
|
if (h[0],y) in game.snake[1:]:
|
||||||
|
queue_in_front=1
|
||||||
|
break
|
||||||
|
|
||||||
reward=0
|
reward=0
|
||||||
if event==0:
|
if event==0:
|
||||||
attempt+=1
|
attempt+=1
|
||||||
|
@ -76,12 +99,9 @@ def event_handler(game,event):
|
||||||
elif event==1:
|
elif event==1:
|
||||||
reward=1
|
reward=1
|
||||||
attempt=0
|
attempt=0
|
||||||
# Avoid infinite loop
|
|
||||||
if attempt>3000:
|
|
||||||
reward=-1
|
|
||||||
attempt=0
|
|
||||||
# This comes from me; I do not know if it is the best way to identify a state
|
# This comes from me; I do not know if it is the best way to identify a state
|
||||||
state=2**11*snake_go_up+2**10*snake_go_right+2**9*snake_go_down+2**8*snake_go_left+2**7*apple_up+2**6*apple_right+2**5*apple_down+2**4*apple_left+2**3*obstacle_up+2**2*obstacle_right+2**1*obstacle_down+obstacle_left
|
state=2**12*queue_in_front+2**11*snake_go_up+2**10*snake_go_right+2**9*snake_go_down+2**8*snake_go_left+2**7*apple_up+2**6*apple_right+2**5*apple_down+2**4*apple_left+2**3*obstacle_up+2**2*obstacle_right+2**1*obstacle_down+obstacle_left
|
||||||
|
|
||||||
# Choose an action
|
# Choose an action
|
||||||
action=random.choice((0,1,2,3))
|
action=random.choice((0,1,2,3))
|
||||||
|
@ -91,6 +111,10 @@ def event_handler(game,event):
|
||||||
#action = random.choice(options)
|
#action = random.choice(options)
|
||||||
action=np.argmax(qtable[state])
|
action=np.argmax(qtable[state])
|
||||||
|
|
||||||
|
# Avoid infinite loop
|
||||||
|
if attempt>game.grid_height*game.grid_width:
|
||||||
|
return(-1)
|
||||||
|
|
||||||
# Update current state Q
|
# Update current state Q
|
||||||
if last_state != None:
|
if last_state != None:
|
||||||
qtable[last_state,last_action]=qtable[last_state,last_action]+0.7*(reward+0.9*np.max(qtable[state])-qtable[last_state,last_action])
|
qtable[last_state,last_action]=qtable[last_state,last_action]+0.7*(reward+0.9*np.max(qtable[state])-qtable[last_state,last_action])
|
||||||
|
@ -106,13 +130,18 @@ def event_handler(game,event):
|
||||||
elif action==3:
|
elif action==3:
|
||||||
snake_action=9
|
snake_action=9
|
||||||
game.direction=snake_action
|
game.direction=snake_action
|
||||||
|
return(0)
|
||||||
|
|
||||||
if os.path.exists("qtable.txt"):
|
if os.path.exists("qtable.txt"):
|
||||||
qtable=np.loadtxt("qtable.txt")
|
qtable=np.loadtxt("qtable.txt")
|
||||||
|
|
||||||
|
perf=0
|
||||||
for i in range(0,10000):
|
for i in range(0,10000):
|
||||||
last_state=None
|
last_state=None
|
||||||
last_action=None
|
last_action=None
|
||||||
score=game.run(event_handler=event_handler)
|
score=game.run(event_handler=event_handler)
|
||||||
if i%100 == 0:
|
attempt=0
|
||||||
|
if i%10 == 0:
|
||||||
np.savetxt('qtable.txt',qtable)
|
np.savetxt('qtable.txt',qtable)
|
||||||
print("Game ended with "+str(score))
|
perf=max(perf,score)
|
||||||
|
print("Game ended with "+str(score)+" best so far is "+str(perf))
|
5754
qtable.txt
5754
qtable.txt
File diff suppressed because it is too large
Load diff
9
snake.py
9
snake.py
|
@ -7,7 +7,7 @@ class Snake:
|
||||||
Programmable Game of Snake written in PyGame
|
Programmable Game of Snake written in PyGame
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, margin=80,length=4,grid_width=30,grid_height=30, grid_pts=30,fps=80):
|
def __init__(self, startat=(0,0), margin=80,length=4,grid_width=30,grid_height=30, grid_pts=30,fps=80):
|
||||||
# Init attributes
|
# Init attributes
|
||||||
self.grid_width=grid_width
|
self.grid_width=grid_width
|
||||||
self.grid_height=grid_height
|
self.grid_height=grid_height
|
||||||
|
@ -16,6 +16,7 @@ class Snake:
|
||||||
self.default_length=length
|
self.default_length=length
|
||||||
self.attempt=0
|
self.attempt=0
|
||||||
self.fps=fps
|
self.fps=fps
|
||||||
|
self.startat=startat
|
||||||
# Setup pygame
|
# Setup pygame
|
||||||
pygame.init()
|
pygame.init()
|
||||||
self.font=pygame.font.SysFont(pygame.font.get_default_font(), int(self.margin/2))
|
self.font=pygame.font.SysFont(pygame.font.get_default_font(), int(self.margin/2))
|
||||||
|
@ -26,7 +27,7 @@ class Snake:
|
||||||
"""
|
"""
|
||||||
Reset game state
|
Reset game state
|
||||||
"""
|
"""
|
||||||
self.snake=[(0,0)]*self.default_length
|
self.snake=[self.startat]*self.default_length
|
||||||
self.direction=3 # Like clock (12=up, 3=right, 6=bottom, 9=left)
|
self.direction=3 # Like clock (12=up, 3=right, 6=bottom, 9=left)
|
||||||
self.new_apple()
|
self.new_apple()
|
||||||
self.score=0
|
self.score=0
|
||||||
|
@ -156,7 +157,9 @@ class Snake:
|
||||||
break
|
break
|
||||||
# Check if an event handler is available
|
# Check if an event handler is available
|
||||||
if event_handler!=None:
|
if event_handler!=None:
|
||||||
event_handler(self,last_event)
|
code=event_handler(self,last_event)
|
||||||
|
if code < 0:
|
||||||
|
break
|
||||||
last_event=0
|
last_event=0
|
||||||
self.move()
|
self.move()
|
||||||
# Check for eating apple
|
# Check for eating apple
|
||||||
|
|
Loading…
Add table
Reference in a new issue