From 657578547f60673e227a5e97833b137083e71c2d Mon Sep 17 00:00:00 2001 From: Loic Guegan Date: Wed, 2 Nov 2022 12:18:07 +0100 Subject: [PATCH] Minor changes --- qlearning.py | 47 ++++++++++++++++++++++++++++++++++------------- snake.py | 6 +++++- 2 files changed, 39 insertions(+), 14 deletions(-) diff --git a/qlearning.py b/qlearning.py index ee0b0a9..3c8111d 100755 --- a/qlearning.py +++ b/qlearning.py @@ -27,15 +27,16 @@ class QTable: ##### Reward +1 when eat an apple ##### Reward -10 when hit obstacle """ - def __init__(self, file, save_every=10): + def __init__(self, file, save_every=5000): self.file=file self.save_every=save_every - self.update_counter=0 + self.save_counter=0 if os.path.exists(file): self.qtable=np.loadtxt(file) else: self.qtable=np.zeros((2**13, 4)) - + with open(file+"_generation","w") as f: + f.write("0") def isWall(self,h,game): if h[0]<0 or h[1]<0 or h[0] >= game.grid_width or h[1] >= game.grid_height: @@ -108,10 +109,18 @@ class QTable: alpha=0.5 gamma=0.9 self.qtable[state,action]=self.qtable[state,action]+alpha*(reward+gamma*np.max(self.qtable[new_state])-self.qtable[state,action]) - self.update_counter+=1 - if self.update_counter>=self.save_every: + self.save_counter+=1 + if self.save_counter>=self.save_every: np.savetxt(self.file,self.qtable) - self.update_counter=0 + if os.path.exists(self.file+"_generation"): + generation=0 + with open(self.file+"_generation","r") as f: + generation=int(f.readline().rstrip()) + generation+=self.save_every + with open(self.file+"_generation","w") as f: + f.write(str(generation)) + print("Checkpointing generation "+str(generation)) + self.save_counter=0 def get_action(self,state): # Choose an action @@ -122,13 +131,16 @@ class QTable: #action = random.choice(options) action=np.argmax(self.qtable[state]) return(action) + + def get_random_action(self): + return(random.choice((0,1,2,3))) # Perform learning -width,height=10,10 +width,height=50,30 perf=0 last_state=None last_action=None @@ -138,25 +150,34 @@ while True: result=0 stuck=0 stuck_tolerance=1 + stuck_count=0 state=qtable.get_state(game) while result >= 0: action=qtable.get_action(state) result=game.play3(action) new_state=qtable.get_state(game) - # Agent is stuck - if stuck>=(game.grid_width*game.grid_height)/stuck_tolerance: - game.new_game() - break - # Compute reward and update stuck reward=0 if result==-1: reward=-10 stuck=0 + stuck_count=0 elif result==1: - reward=50 + reward=1 stuck=0 + stuck_count=0 + + # Agent is stuck + if stuck>=(game.grid_width*game.grid_height)/stuck_tolerance: + stuck=0 + stuck_count+=1 + action=qtable.get_random_action() + print("Stuck!") + if stuck_count>2: + stuck_count=0 + game.new_game() + break # Apply learning qtable.apply_bellman(state,action,new_state,reward) diff --git a/snake.py b/snake.py index 7526731..330ba12 100755 --- a/snake.py +++ b/snake.py @@ -215,4 +215,8 @@ class Snake: if self.play(self.direction,handle_quit=False) <0: break - \ No newline at end of file + + +if __name__ == "__main__": + game=Snake(length=50) + game.play_with_keyboard() \ No newline at end of file