Minor changes

This commit is contained in:
Loic Guegan 2022-11-02 12:18:07 +01:00
parent 65280331d3
commit 657578547f
2 changed files with 39 additions and 14 deletions

View file

@ -27,15 +27,16 @@ class QTable:
##### Reward +1 when eat an apple ##### Reward +1 when eat an apple
##### Reward -10 when hit obstacle ##### Reward -10 when hit obstacle
""" """
def __init__(self, file, save_every=10): def __init__(self, file, save_every=5000):
self.file=file self.file=file
self.save_every=save_every self.save_every=save_every
self.update_counter=0 self.save_counter=0
if os.path.exists(file): if os.path.exists(file):
self.qtable=np.loadtxt(file) self.qtable=np.loadtxt(file)
else: else:
self.qtable=np.zeros((2**13, 4)) self.qtable=np.zeros((2**13, 4))
with open(file+"_generation","w") as f:
f.write("0")
def isWall(self,h,game): def isWall(self,h,game):
if h[0]<0 or h[1]<0 or h[0] >= game.grid_width or h[1] >= game.grid_height: if h[0]<0 or h[1]<0 or h[0] >= game.grid_width or h[1] >= game.grid_height:
@ -108,10 +109,18 @@ class QTable:
alpha=0.5 alpha=0.5
gamma=0.9 gamma=0.9
self.qtable[state,action]=self.qtable[state,action]+alpha*(reward+gamma*np.max(self.qtable[new_state])-self.qtable[state,action]) self.qtable[state,action]=self.qtable[state,action]+alpha*(reward+gamma*np.max(self.qtable[new_state])-self.qtable[state,action])
self.update_counter+=1 self.save_counter+=1
if self.update_counter>=self.save_every: if self.save_counter>=self.save_every:
np.savetxt(self.file,self.qtable) np.savetxt(self.file,self.qtable)
self.update_counter=0 if os.path.exists(self.file+"_generation"):
generation=0
with open(self.file+"_generation","r") as f:
generation=int(f.readline().rstrip())
generation+=self.save_every
with open(self.file+"_generation","w") as f:
f.write(str(generation))
print("Checkpointing generation "+str(generation))
self.save_counter=0
def get_action(self,state): def get_action(self,state):
# Choose an action # Choose an action
@ -123,12 +132,15 @@ class QTable:
action=np.argmax(self.qtable[state]) action=np.argmax(self.qtable[state])
return(action) return(action)
def get_random_action(self):
return(random.choice((0,1,2,3)))
# Perform learning # Perform learning
width,height=10,10 width,height=50,30
perf=0 perf=0
last_state=None last_state=None
last_action=None last_action=None
@ -138,25 +150,34 @@ while True:
result=0 result=0
stuck=0 stuck=0
stuck_tolerance=1 stuck_tolerance=1
stuck_count=0
state=qtable.get_state(game) state=qtable.get_state(game)
while result >= 0: while result >= 0:
action=qtable.get_action(state) action=qtable.get_action(state)
result=game.play3(action) result=game.play3(action)
new_state=qtable.get_state(game) new_state=qtable.get_state(game)
# Agent is stuck
if stuck>=(game.grid_width*game.grid_height)/stuck_tolerance:
game.new_game()
break
# Compute reward and update stuck # Compute reward and update stuck
reward=0 reward=0
if result==-1: if result==-1:
reward=-10 reward=-10
stuck=0 stuck=0
stuck_count=0
elif result==1: elif result==1:
reward=50 reward=1
stuck=0 stuck=0
stuck_count=0
# Agent is stuck
if stuck>=(game.grid_width*game.grid_height)/stuck_tolerance:
stuck=0
stuck_count+=1
action=qtable.get_random_action()
print("Stuck!")
if stuck_count>2:
stuck_count=0
game.new_game()
break
# Apply learning # Apply learning
qtable.apply_bellman(state,action,new_state,reward) qtable.apply_bellman(state,action,new_state,reward)

View file

@ -216,3 +216,7 @@ class Snake:
if self.play(self.direction,handle_quit=False) <0: if self.play(self.direction,handle_quit=False) <0:
break break
if __name__ == "__main__":
game=Snake(length=50)
game.play_with_keyboard()