Minor changes
This commit is contained in:
parent
65280331d3
commit
657578547f
2 changed files with 39 additions and 14 deletions
47
qlearning.py
47
qlearning.py
|
@ -27,15 +27,16 @@ class QTable:
|
||||||
##### Reward +1 when eat an apple
|
##### Reward +1 when eat an apple
|
||||||
##### Reward -10 when hit obstacle
|
##### Reward -10 when hit obstacle
|
||||||
"""
|
"""
|
||||||
def __init__(self, file, save_every=10):
|
def __init__(self, file, save_every=5000):
|
||||||
self.file=file
|
self.file=file
|
||||||
self.save_every=save_every
|
self.save_every=save_every
|
||||||
self.update_counter=0
|
self.save_counter=0
|
||||||
if os.path.exists(file):
|
if os.path.exists(file):
|
||||||
self.qtable=np.loadtxt(file)
|
self.qtable=np.loadtxt(file)
|
||||||
else:
|
else:
|
||||||
self.qtable=np.zeros((2**13, 4))
|
self.qtable=np.zeros((2**13, 4))
|
||||||
|
with open(file+"_generation","w") as f:
|
||||||
|
f.write("0")
|
||||||
|
|
||||||
def isWall(self,h,game):
|
def isWall(self,h,game):
|
||||||
if h[0]<0 or h[1]<0 or h[0] >= game.grid_width or h[1] >= game.grid_height:
|
if h[0]<0 or h[1]<0 or h[0] >= game.grid_width or h[1] >= game.grid_height:
|
||||||
|
@ -108,10 +109,18 @@ class QTable:
|
||||||
alpha=0.5
|
alpha=0.5
|
||||||
gamma=0.9
|
gamma=0.9
|
||||||
self.qtable[state,action]=self.qtable[state,action]+alpha*(reward+gamma*np.max(self.qtable[new_state])-self.qtable[state,action])
|
self.qtable[state,action]=self.qtable[state,action]+alpha*(reward+gamma*np.max(self.qtable[new_state])-self.qtable[state,action])
|
||||||
self.update_counter+=1
|
self.save_counter+=1
|
||||||
if self.update_counter>=self.save_every:
|
if self.save_counter>=self.save_every:
|
||||||
np.savetxt(self.file,self.qtable)
|
np.savetxt(self.file,self.qtable)
|
||||||
self.update_counter=0
|
if os.path.exists(self.file+"_generation"):
|
||||||
|
generation=0
|
||||||
|
with open(self.file+"_generation","r") as f:
|
||||||
|
generation=int(f.readline().rstrip())
|
||||||
|
generation+=self.save_every
|
||||||
|
with open(self.file+"_generation","w") as f:
|
||||||
|
f.write(str(generation))
|
||||||
|
print("Checkpointing generation "+str(generation))
|
||||||
|
self.save_counter=0
|
||||||
|
|
||||||
def get_action(self,state):
|
def get_action(self,state):
|
||||||
# Choose an action
|
# Choose an action
|
||||||
|
@ -122,13 +131,16 @@ class QTable:
|
||||||
#action = random.choice(options)
|
#action = random.choice(options)
|
||||||
action=np.argmax(self.qtable[state])
|
action=np.argmax(self.qtable[state])
|
||||||
return(action)
|
return(action)
|
||||||
|
|
||||||
|
def get_random_action(self):
|
||||||
|
return(random.choice((0,1,2,3)))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Perform learning
|
# Perform learning
|
||||||
width,height=10,10
|
width,height=50,30
|
||||||
perf=0
|
perf=0
|
||||||
last_state=None
|
last_state=None
|
||||||
last_action=None
|
last_action=None
|
||||||
|
@ -138,25 +150,34 @@ while True:
|
||||||
result=0
|
result=0
|
||||||
stuck=0
|
stuck=0
|
||||||
stuck_tolerance=1
|
stuck_tolerance=1
|
||||||
|
stuck_count=0
|
||||||
state=qtable.get_state(game)
|
state=qtable.get_state(game)
|
||||||
while result >= 0:
|
while result >= 0:
|
||||||
action=qtable.get_action(state)
|
action=qtable.get_action(state)
|
||||||
result=game.play3(action)
|
result=game.play3(action)
|
||||||
new_state=qtable.get_state(game)
|
new_state=qtable.get_state(game)
|
||||||
|
|
||||||
# Agent is stuck
|
|
||||||
if stuck>=(game.grid_width*game.grid_height)/stuck_tolerance:
|
|
||||||
game.new_game()
|
|
||||||
break
|
|
||||||
|
|
||||||
# Compute reward and update stuck
|
# Compute reward and update stuck
|
||||||
reward=0
|
reward=0
|
||||||
if result==-1:
|
if result==-1:
|
||||||
reward=-10
|
reward=-10
|
||||||
stuck=0
|
stuck=0
|
||||||
|
stuck_count=0
|
||||||
elif result==1:
|
elif result==1:
|
||||||
reward=50
|
reward=1
|
||||||
stuck=0
|
stuck=0
|
||||||
|
stuck_count=0
|
||||||
|
|
||||||
|
# Agent is stuck
|
||||||
|
if stuck>=(game.grid_width*game.grid_height)/stuck_tolerance:
|
||||||
|
stuck=0
|
||||||
|
stuck_count+=1
|
||||||
|
action=qtable.get_random_action()
|
||||||
|
print("Stuck!")
|
||||||
|
if stuck_count>2:
|
||||||
|
stuck_count=0
|
||||||
|
game.new_game()
|
||||||
|
break
|
||||||
|
|
||||||
# Apply learning
|
# Apply learning
|
||||||
qtable.apply_bellman(state,action,new_state,reward)
|
qtable.apply_bellman(state,action,new_state,reward)
|
||||||
|
|
6
snake.py
6
snake.py
|
@ -215,4 +215,8 @@ class Snake:
|
||||||
|
|
||||||
if self.play(self.direction,handle_quit=False) <0:
|
if self.play(self.direction,handle_quit=False) <0:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
game=Snake(length=50)
|
||||||
|
game.play_with_keyboard()
|
Loading…
Add table
Reference in a new issue