Minor changes

parent 01db5773d5
commit c5da8a1b7b

3 changed files with 378 additions and 4505 deletions
qlearning.py (49 changes)
@@ -20,10 +20,9 @@ class QTable:
     # Obstacle at right?
     # Obstacle at down?
     # Obstacle at left?
-    # Tail in front?
-    ##### Totally 13 boolean features so 2^13=8192 states
+    ##### Totally 12 boolean features so 2^12=4096 states
     ##### Totally 4 actions for the AI (up, right, down, left)
-    ##### Totally 4*2^13 thus 32768 table entries
+    ##### Totally 4*2^12 thus 16384 table entries
     ##### Reward +1 when eating an apple
     ##### Reward -10 when hitting an obstacle
     """
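Note: the sizing in the updated docstring follows directly from the feature count. A quick sanity check (a sketch; the variable names are illustrative, not from the repo):

    n_features = 12             # boolean features per observation
    n_states = 2 ** n_features  # 4096 distinct states
    n_actions = 4               # up, right, down, left
    assert n_states * n_actions == 16384  # Q-table entries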
@@ -34,7 +33,7 @@ class QTable:
         if os.path.exists(file):
             self.qtable=np.loadtxt(file)
         else:
-            self.qtable=np.zeros((2**13, 4))
+            self.qtable=np.zeros((2**12, 4))
             with open(file+"_generation","w") as f:
                 f.write("0")
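Note: the constructor loads the table from a plain-text file when it exists and otherwise starts from zeros. A minimal sketch of that load-or-init pattern, assuming the save side uses the matching np.savetxt (the save code is outside this hunk):

    import os
    import numpy as np

    file = "qtable.txt"
    # Load a previously trained table, or start fresh with one row per state.
    qtable = np.loadtxt(file) if os.path.exists(file) else np.zeros((2**12, 4))
    np.savetxt(file, qtable)  # presumed counterpart that save_every eventually triggers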
@@ -67,30 +66,8 @@ class QTable:
         obstacle_down=(down in game.snake or self.isWall(down, game))
         obstacle_left=(left in game.snake or self.isWall(left, game))

-        tail_in_front=0
-        if snake_go_right:
-            for x in range(h[0],game.grid_width):
-                if (x,h[1]) in game.snake[1:]:
-                    tail_in_front=1
-                    break
-        elif snake_go_left:
-            for x in range(0,h[0]):
-                if (x,h[1]) in game.snake[1:]:
-                    tail_in_front=1
-                    break
-        elif snake_go_up:
-            for y in range(0,h[1]):
-                if (h[0],y) in game.snake[1:]:
-                    tail_in_front=1
-                    break
-        elif snake_go_down:
-            for y in range(h[1],game.grid_height):
-                if (h[0],y) in game.snake[1:]:
-                    tail_in_front=1
-                    break
         # This comes from me; I do not know if it is the best way to identify a state
         state=\
-            2**12*tail_in_front+\
             2**11*snake_go_up+\
             2**10*snake_go_right+\
             2**9*snake_go_down+\
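Note: after the removal, each remaining boolean feature still occupies its own bit of the state index, which is what makes the 2^12 states distinct. A generic sketch of the same encoding, assuming a features list ordered from most to least significant bit (encode_state is an illustrative name, not in the repo):

    def encode_state(features):
        # features: 12 bools/ints, e.g. [snake_go_up, snake_go_right, snake_go_down, ...]
        state = 0
        for f in features:
            state = (state << 1) | int(f)  # shift in one bit per feature
        return state  # always in range(2**len(features))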
@@ -106,8 +83,8 @@ class QTable:
         return(state)

     def apply_bellman(self,state,action,new_state,reward):
-        alpha=0.5
-        gamma=0.9
+        alpha=0.1
+        gamma=0.95
         self.qtable[state,action]=self.qtable[state,action]+alpha*(reward+gamma*np.max(self.qtable[new_state])-self.qtable[state,action])
         self.save_counter+=1
         if self.save_counter>=self.save_every:
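Note: the unchanged update line is the standard tabular Q-learning rule, Q(s,a) ← Q(s,a) + α·(r + γ·max_a′ Q(s′,a′) − Q(s,a)); only the constants move, to a slower learning rate and a longer horizon. One step with the new values (the numbers are made up for illustration):

    alpha, gamma = 0.1, 0.95
    q_sa = 0.0        # current estimate qtable[state,action]
    reward = 1.0      # an apple was just eaten
    best_next = 0.5   # np.max(qtable[new_state])
    q_sa += alpha * (reward + gamma * best_next - q_sa)
    print(q_sa)       # 0.1475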
@@ -140,18 +117,17 @@ class QTable:


 # Perform learning
-width,height=40,30 # My advice is to start with a small grid 5x5 to have many interactions and avoid an early toy effect
+width,height=80,50 # My advice is to start with a small grid 5x5 to have many interactions and avoid an early toy effect
 perf=0
 perf_list=list()
 last_state=None
 last_action=None
-game=Snake(length=1,fps=500,startat=(random.randint(0,width-1),random.randint(0,height-1)),grid_width=width,grid_height=height)
+game=Snake(length=1,fps=500,grid_pts=20,startat=(random.randint(0,width-1),random.randint(0,height-1)),grid_width=width,grid_height=height)
 qtable=QTable("qtable.txt")
 while True:
     result=0
     stuck=0
     stuck_tolerance=1
-    stuck_count=0
     state=qtable.get_state(game)
     while result >= 0:
         action=qtable.get_action(state)
@@ -163,21 +139,14 @@ while True:
         if result==-1:
             reward=-10
             stuck=0
-            stuck_count=0
         elif result==1:
             reward=1
             stuck=0
-            stuck_count=0

         # Agent is stuck
         if stuck>=(game.grid_width*game.grid_height)/stuck_tolerance:
-            stuck=0
-            stuck_count+=1
-            game.new_apple()
-            print("Stuck! Try with a new apple...")
-            if stuck_count>2:
-                print("Can't get out of stuck. Abort!")
-                stuck_count=0
+            print("Stuck! Apply penalty and abort!")
+            qtable.apply_bellman(state,action,new_state,-1)
             game.new_game()
             break

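Note: with the retry-with-a-new-apple path gone, a stall now costs a single −1 update (milder than the −10 crash reward) and ends the episode. It is worth noting how long that takes to trigger on the enlarged grid, assuming stuck counts moves since the last apple:

    width, height = 80, 50
    stuck_tolerance = 1
    print((width * height) / stuck_tolerance)  # 4000.0 moves without progress before the penalty fires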
qtable.txt (4832 changes)

File diff suppressed because it is too large.
qtable.txt_generation

@@ -1 +1 @@
-400000
+395000